import numpy as np
from causallearn.utils.cit import CIT
from itertools import combinations
from copy import deepcopy
import pydotplus
import os
from causallearn.search.ScoreBased.GES import ges
import causaldag as cd
from scipy.stats import ks_2samp

def count_precision_recall_f1(tp, fp, fn):
    """Derive precision, recall and F1 from true/false positive and false negative counts.

    Args:
        tp: number of true positives.
        fp: number of false positives.
        fn: number of false negatives.

    Returns:
        A ``(precision, recall, f1)`` tuple.
        - ``precision`` is ``None`` when ``tp + fp == 0``.
        - ``recall`` is ``None`` when ``tp + fn == 0``.
        - ``f1`` is ``None`` if either of them is ``None``, ``0.0`` if either
          is exactly zero, and their harmonic mean otherwise.
    """
    n_predicted = tp + fp
    n_actual = tp + fn
    precision = float(tp) / n_predicted if n_predicted != 0 else None
    recall = float(tp) / n_actual if n_actual != 0 else None

    # Undefined inputs propagate as an undefined F1.
    if precision is None or recall is None:
        return precision, recall, None
    # Avoid 0/0 in the harmonic mean when both are zero.
    if precision == 0 or recall == 0:
        return precision, recall, 0.0
    return precision, recall, float(2 * precision * recall) / (precision + recall)


def count_dag_accuracy(B_bin_true, B_bin_est, return_all=False):
    """Compare an estimated directed graph against the ground-truth graph.

    Both arguments are square adjacency matrices of the same size. A non-zero
    entry marks a directed edge, and both matrices must use the same
    orientation convention.

    Each predicted edge is classified as one of:
    - true positive: it matches a true edge in the same direction;
    - reversed: only the opposite direction is a true edge;
    - false positive: neither direction is a true edge.
    Reversed edges are counted as false positives for precision and FDR.

    Args:
        B_bin_true: ground-truth adjacency matrix, shape (d, d).
        B_bin_est: estimated adjacency matrix, shape (d, d).
        return_all: if True, also include ``'fdr'``, ``'tpr'``, ``'fpr'`` and
            ``'nnz'`` (number of predicted edges) in the result. Defaults to
            False, which returns only the four keys below.

    Returns:
        dict with:
        - ``'f1'``, ``'precision'``, ``'recall'``: each may be ``None``; see
          ``count_precision_recall_f1``.
        - ``'shd'``: structural Hamming distance, i.e. extra skeleton edges
          plus missing skeleton edges plus reversed edges.
        - with ``return_all``, additionally ``'fdr'``, ``'tpr'`` and ``'fpr'``
          (each ``None`` when its denominator is 0) and ``'nnz'``.
    """
    # Linear (flattened) indices of the non-zero entries.
    pred = np.flatnonzero(B_bin_est)
    cond = np.flatnonzero(B_bin_true)
    cond_reversed = np.flatnonzero(B_bin_true.T)
    cond_skeleton = np.concatenate([cond, cond_reversed])

    true_pos = np.intersect1d(pred, cond, assume_unique=True)
    # Predicted edges absent from the true skeleton in either direction.
    false_pos = np.setdiff1d(pred, cond_skeleton, assume_unique=True)
    # Predicted edges whose opposite direction is the true edge.
    extra = np.setdiff1d(pred, cond, assume_unique=True)
    reverse = np.intersect1d(extra, cond_reversed, assume_unique=True)
    false_neg = np.setdiff1d(cond, true_pos, assume_unique=True)

    # SHD: skeleton differences, read from the lower triangle (incl. diagonal)
    # of the symmetrised matrices, plus one per reversed edge.
    pred_lower = np.flatnonzero(np.tril(B_bin_est + B_bin_est.T))
    cond_lower = np.flatnonzero(np.tril(B_bin_true + B_bin_true.T))
    extra_lower = np.setdiff1d(pred_lower, cond_lower, assume_unique=True)
    missing_lower = np.setdiff1d(cond_lower, pred_lower, assume_unique=True)
    shd = len(extra_lower) + len(missing_lower) + len(reverse)

    n_wrong = len(reverse) + len(false_pos)
    precision, recall, f1 = count_precision_recall_f1(tp=len(true_pos),
                                                      fp=n_wrong,
                                                      fn=len(false_neg))
    result = {'f1': f1, 'precision': precision, 'recall': recall, 'shd': shd}

    if return_all:
        # These ratios were previously computed and then discarded. They are
        # now returned only when the caller explicitly asks for them.
        d = B_bin_true.shape[0]
        pred_size = len(pred)
        cond_neg_size = 0.5 * d * (d - 1) - len(cond)
        result['fdr'] = None if pred_size == 0 else float(n_wrong) / pred_size
        result['tpr'] = None if len(cond) == 0 else float(len(true_pos)) / len(cond)
        result['fpr'] = None if cond_neg_size == 0 else float(n_wrong) / cond_neg_size
        result['nnz'] = pred_size
    return result

def regression(X, Y):
    """Return the least-squares residuals of X regressed on Y and of Y regressed on X.

    Args:
        X: 2-D array, shape (n_samples, p).
        Y: 2-D array, shape (n_samples, k).

    Returns:
        ``(residual_x, residual_y)`` with the same shapes as ``X`` and ``Y``:
        - ``residual_x = X - Y @ beta``, where ``beta`` are the least-squares
          coefficients of X on Y (no intercept);
        - ``residual_y`` is the symmetric residual of Y on X.
    """
    # lstsq avoids forming inv(Y.T @ Y) explicitly. That is numerically more
    # stable, and for a rank-deficient regressor it returns the minimum-norm
    # solution instead of raising LinAlgError.
    coef_x_on_y = np.linalg.lstsq(Y, X, rcond=None)[0]
    coef_y_on_x = np.linalg.lstsq(X, Y, rcond=None)[0]
    residual_x = X - Y @ coef_x_on_y
    residual_y = Y - X @ coef_y_on_x
    return residual_x, residual_y

def get_dag_from_pdag(B_bin_pdag):
    """Convert a binary PDAG adjacency matrix into a DAG adjacency matrix.

    The conversion is done by causaldag's
    ``PDAG.from_amat(...).to_dag().to_amat()``. That call does not preserve
    the input shape: it returns a sub-matrix together with the list of node
    indices it covers. The sub-matrix is therefore written back into a zero
    matrix with the same shape and dtype as the input.

    An all-zero input short-circuits to an all-zero result without calling
    causaldag.
    """
    full_dag = np.zeros_like(B_bin_pdag)
    if np.all(B_bin_pdag == 0):
        return full_dag

    # to_dag() extends the PDAG to a DAG.
    sub_amat, node_ids = cd.PDAG.from_amat(B_bin_pdag).to_dag().to_amat()
    full_dag[np.ix_(node_ids, node_ids)] = sub_amat
    return full_dag

# Intervention type. The evaluation loop uses it both as a directory name and
# in the data filename: f'.../{interven}/sample_{i}/sample_{interven}.npz'.
interven = 'hard'
# Alternative configuration (commented out): variables with matching sample ids.
# variables = [6,9,12,15,18]
# sample = [[0, 1, 11, 17, 18, 23, 25, 29, 36, 37, 38, 39, 48, 54, 60, 63, 75, 77, 88, 93],
#           [ 10, 49, 58, 93, 129, 132, 198, 205, 207, 209, 215, 242, 245, 272, 283, 288, 310, 312, 320, 335],
#           [ 8, 80, 128, 176, 182, 187, 201, 242, 263, 320, 465, 476, 526, 591, 610, 616, 658, 664, 701, 722],
#           [ 76, 472, 596, 629, 805, 959, 1041, 1049, 1084, 1110, 1160, 1237, 1399, 1403, 1410, 1889],
#           [ 219, 346, 739, 765, 818, 928, 1282, 1753, 1865, 2009, 2337, 2606, 2709, 2758]]

# Alternative sample-id lists (commented out). The `variables` setting they
# were drawn for is not recorded here; confirm it before re-enabling them.
# sample = [[47, 125, 231, 415, 461, 597, 639, 928, 1012, 1027, 1048, 1075, 1144, 1190, 1399, 1586, 1661, 1750, 1927, 1928],
#           [9, 29, 55, 72, 90, 121, 132, 171, 234, 239, 254, 276, 281, 282, 294, 303, 340, 429, 445, 449],
#           [77, 148, 193, 235, 255, 287, 307, 336, 388, 415, 506, 559, 691, 814, 944, 1142, 1348, 1442, 1547, 1589],
#           [23, 61, 315, 463, 869, 1157, 1280, 1325, 1406, 1428, 1481, 1763, 1786, 1796, 1855, 1893],
#           [57, 70, 169, 541, 573, 931, 963, 1226, 1240, 1333, 1378]]
# Active configuration. The evaluation loop pairs variables[j] with sample[j]:
# - variables[j] becomes the `v_{d}` directory name (presumably the number of
#   variables per graph; TODO confirm);
# - sample[j] lists the sample ids to evaluate for that setting. Ids whose
#   directory does not exist are skipped.
variables = [5,6,7,8,9]
sample = [[ 0, 7, 9, 29, 65, 104, 108, 138, 146, 174, 211, 235, 236, 256, 268, 277, 295, 302, 354, 365],
          [ 108, 119, 132, 145, 177, 246, 252, 280, 292, 309, 329, 341, 373, 416, 449, 475, 573, 590, 641, 659],
          [ 128, 146, 180, 213, 223, 257, 307, 321, 342, 455, 564, 676, 690, 771, 788, 791, 795, 898, 973, 1007],
          [ 209, 232, 495, 685, 909, 943, 966, 1082, 1173, 1237, 1253, 1280, 1382, 1415, 1506, 1520, 1550, 1577, 1687, 1763],
          [ 75, 148, 258, 264, 450, 554, 825, 989, 1053, 1092, 1129, 1176, 1419, 1423, 1930, 1935, 1948, 1999]]
# Another alternative sample-id list (commented out). The matching
# `variables` setting is not recorded here.
# sample = [[ 216, 219, 246, 252, 257, 258, 265, 277, 287, 301, 327, 329, 344, 361, 395, 403, 411, 417, 423],
#           [ 18, 39, 58, 96, 123, 160, 224, 282, 302, 318, 362, 390, 391, 406, 409, 474, 607, 615, 647, 669],
#           [ 30, 51, 57, 61, 157, 214, 286, 292, 338, 356, 362, 363, 411, 505, 506, 509, 525, 574, 639],
#           [ 12, 49, 63, 146, 220, 249, 300, 412, 445, 458, 528, 671, 692, 730, 809, 928, 938, 1008, 1064],
#           [ 29, 68, 91, 120, 175, 194, 210, 216, 315, 320, 352, 432, 440, 533, 576, 597, 602, 676, 743, 868]]
# Evaluate GES + intervention-based edge orientation for every
# (variables[j], sample[j]) setting, then print averaged metrics per setting.

# Significance level of the two-sample KS test used to orient edges.
ks_alpha = 0.05
data_root = '/Users/gongxu.luo/Desktop/gene/cell-cell communication'

for d, sample_ids in zip(variables, sample):
    acc_dag = []  # holds precision values
    f1_dag = []
    recall_dag = []
    shd_dag = []
    for i in sample_ids:
        file_path = f'{data_root}/v_{d}/general/{interven}/sample_{i}'
        if not os.path.exists(file_path):
            continue
        # np.load on an .npz returns an NpzFile that keeps the archive open.
        # The with-block closes it deterministically.
        with np.load(os.path.join(file_path, f'sample_{interven}.npz'), allow_pickle=True) as data:
            obs = data['obs']
            true_dag = data['dag']
            cg = ges(obs, 'local_score_CV_general')
            # cg = ges(obs, 'local_score_BIC')

            # causal-learn encodes i -> j as graph[j, i] = 1 and
            # graph[i, j] = -1, so the two entries sum to 0. Adjacency must
            # therefore test each entry for != 0. It is taken from the
            # unmodified GES output, and orientations go into a fresh 0/1
            # matrix, so the decision for (m, n) cannot affect the adjacency
            # test for (n, m) and no -1 entries reach the metrics.
            cpdag = cg['G'].graph
            adjacent = (cpdag != 0) | (cpdag.T != 0)
            dag = np.zeros(cpdag.shape, dtype=int)
            for m in range(cpdag.shape[0]):
                neighbours = np.flatnonzero(adjacent[m])
                if neighbours.size == 0:
                    continue
                # Presumably the samples under an intervention on node m
                # (TODO confirm). Loaded once per m, not once per pair.
                per_data = data[f'per_{m}']
                for n in neighbours:
                    _, p_value = ks_2samp(per_data[:, n], obs[:, n])
                    # Keep m -> n only if X_n's distribution differs between
                    # the per_{m} data and the observational data.
                    dag[m, n] = 0 if p_value > ks_alpha else 1

        ret_dire = count_dag_accuracy(true_dag, dag)
        print("Directions 1 by CausalDAG: ", ret_dire)
        # Undefined metrics (zero denominators) are averaged as 0.
        f1_dag.append(0 if ret_dire['f1'] is None else ret_dire['f1'])
        recall_dag.append(0 if ret_dire['recall'] is None else ret_dire['recall'])
        acc_dag.append(0 if ret_dire['precision'] is None else ret_dire['precision'])
        shd_dag.append(ret_dire['shd'])

    print(f'the average accuracy of dag is {np.mean(acc_dag)}, variance is {np.var(acc_dag)}')
    print(f'the average accuracy of recall of dag is {np.mean(recall_dag)}, variance is {np.var(recall_dag)}')
    print(f'the average f1 score of dag is {np.mean(f1_dag)}, variance is {np.var(f1_dag)}')
    print(f'the average shd of dag is {np.mean(shd_dag)}, variance is {np.var(shd_dag)}')


