import sys
# sys.path.append('..')
import os
import random
import numpy as np
from causal_discovery_algs import LearnStructICD, LearnStructFCI
from causal_discovery_utils.cond_indep_tests import CondIndepParCorr  # import a CI test that estimates partial correlation
from experiment_utils.synthetic_graphs import create_random_dag_with_latents, sample_data_from_dag
from causal_discovery_utils.performance_measures import calc_structural_accuracy_pag, find_true_pag
from matplotlib import pyplot as plt
from itertools import combinations

def count_precision_recall_f1(tp, fp, fn):
    """Return ``(precision, recall, f1)`` from true/false positive and false negative counts.

    Any metric whose denominator is zero is reported as ``None``.  F1 is
    ``None`` whenever precision or recall is undefined, and ``0.0`` when
    either of them is exactly zero.
    """
    n_predicted = tp + fp
    n_actual = tp + fn
    precision = None if n_predicted == 0 else float(tp) / n_predicted
    recall = None if n_actual == 0 else float(tp) / n_actual

    if precision is None or recall is None:
        return precision, recall, None
    if precision == 0 or recall == 0:
        # Harmonic mean degenerates to 0; avoid a 0/0 division.
        return precision, recall, 0.0
    return precision, recall, float(2 * precision * recall) / (precision + recall)


def count_dag_accuracy(B_bin_true, B_bin_est):
    """Compare an estimated graph matrix against the true one, entry by entry.

    An entry counts as "predicted" / "true" when it is nonzero, so the
    matrices may hold binary DAG entries or signed endpoint codes.

    Returns a dict with:
      * ``precision``, ``recall``, ``f1``: from ``count_precision_recall_f1``.
        Here tp = entries nonzero in both, and fp = predicted entries absent
        from the true skeleton plus "reversed" entries (nonzero in the
        estimate where only the transposed true entry is nonzero).
        Each may be ``None`` when undefined.
      * ``shd``: extra + missing skeleton edges (lower triangle, diagonal
        included) plus the number of reversed entries.
    """
    # Linear (row-major) indices of nonzero entries.
    pred = np.flatnonzero(B_bin_est)
    cond = np.flatnonzero(B_bin_true)
    cond_reversed = np.flatnonzero(B_bin_true.T)
    # union1d yields a sorted, duplicate-free array, as the assume_unique
    # calls below require (a plain concatenate repeats indices whenever both
    # B[i, j] and B[j, i] are nonzero).
    cond_skeleton = np.union1d(cond, cond_reversed)
    # true pos
    true_pos = np.intersect1d(pred, cond, assume_unique=True)
    # false pos: predicted where the true graph has no edge at all
    false_pos = np.setdiff1d(pred, cond_skeleton, assume_unique=True)
    # reverse: predicted where only the transposed true entry is nonzero
    extra = np.setdiff1d(pred, cond, assume_unique=True)
    reverse = np.intersect1d(extra, cond_reversed, assume_unique=True)
    # Structural Hamming distance on the skeleton.  Build the skeleton with
    # `!= 0` rather than `B + B.T`: with signed endpoint codes an edge marked
    # (1, -1) would sum to 0 and silently disappear.  For non-negative
    # binary input the two forms are identical.
    pred_skeleton = (B_bin_est != 0) | (B_bin_est.T != 0)
    cond_skeleton_mat = (B_bin_true != 0) | (B_bin_true.T != 0)
    pred_lower = np.flatnonzero(np.tril(pred_skeleton))
    cond_lower = np.flatnonzero(np.tril(cond_skeleton_mat))
    extra_lower = np.setdiff1d(pred_lower, cond_lower, assume_unique=True)
    missing_lower = np.setdiff1d(cond_lower, pred_lower, assume_unique=True)
    shd = len(extra_lower) + len(missing_lower) + len(reverse)
    # false neg
    false_neg = np.setdiff1d(cond, true_pos, assume_unique=True)
    precision, recall, f1 = count_precision_recall_f1(tp=len(true_pos),
                                                      fp=len(reverse) + len(false_pos),
                                                      fn=len(false_neg))
    return {'f1': f1,  'precision': precision, 'recall': recall, 'shd': shd}

def regression(X, Y):
    """Return the OLS residuals of X regressed on Y and of Y regressed on X.

    Both arguments are 2-D arrays sharing the same number of rows.
    ``residual_x = X - Y @ beta_x``, where ``beta_x`` minimises
    ``||X - Y @ beta||``; ``residual_y`` is the symmetric quantity.

    Uses ``np.linalg.lstsq`` instead of forming ``inv(Y.T @ Y)`` explicitly.
    That avoids squaring the condition number, and it still returns the
    minimum-norm fit when the regressors are rank deficient (where ``inv``
    would raise).
    """
    beta_x = np.linalg.lstsq(Y, X, rcond=None)[0]
    beta_y = np.linalg.lstsq(X, Y, rcond=None)[0]
    residual_x = X - Y @ beta_x
    residual_y = Y - X @ beta_y
    return residual_x, residual_y

def get_dag_from_pdag(B_bin_pdag):
    """Convert a PDAG adjacency matrix into a DAG adjacency matrix of the same shape.

    The output always has the same shape and dtype as the input.  According
    to the original author, ``G.to_dag().to_amat()`` from the ``cd`` package
    does not preserve the matrix shape.  Its sub-matrix is therefore
    scattered back into a zero matrix using the node list it returns.
    """
    B_bin_dag = np.zeros_like(B_bin_pdag)
    # An edgeless PDAG is returned as all zeros without calling into `cd`.
    if not np.any(B_bin_pdag != 0):
        return B_bin_dag

    # NOTE(review): `cd` is not imported anywhere in the visible file, so this
    # line raises NameError unless `cd` is provided elsewhere (presumably the
    # `causaldag` package) — confirm and add the import.
    pdag = cd.PDAG.from_amat(B_bin_pdag)
    # to_dag() orients the PDAG into a DAG; to_amat() gives (matrix, node list).
    sub_amat, node_ids = pdag.to_dag().to_amat()
    B_bin_dag[np.ix_(node_ids, node_ids)] = sub_amat
    return B_bin_dag

def dag_to_pag(adj, latent, selection):
    """Build an endpoint-mark matrix from a DAG plus latent and selection structure.

    Codes written (the same values ``dic_to_adj`` uses for '<--', '---' and
    'o--'): 0 = no edge, 1 = arrowhead, -1 = tail, 2 = circle.
    Presumably ``out[a, b]`` is the mark at ``a`` on the a-b edge; confirm.

    Args:
        adj: square DAG matrix where an entry == 1 denotes an edge.  It is
            copied, so the caller's array is not modified.
        latent: iterable of node groups.  Every pair inside a group is
            treated as sharing a latent confounder.
        selection: iterable of node pairs (only the first two items of each
            are used).

    Returns:
        A new numpy array holding the endpoint codes.
    """
    # Work on a copy so the ground-truth DAG passed in is left untouched.
    adj = np.array(adj, copy=True)

    # Every entry == 1 gets a tail mark (-1) in its transposed position.  It
    # is computed from a snapshot, so scan order does not matter.  For a
    # lower-triangular DAG this matches the original strict-lower-triangle
    # loop exactly.
    arrow = adj == 1
    adj[arrow.T & ~arrow] = -1

    for group in latent:
        for a, b in combinations(group, 2):
            # NOTE: `and`, not `&` — `x == 0 & y == 0` parses as the chained
            # comparison `x == (0 & y) == 0`.
            mark_ab, mark_ba = adj[a, b], adj[b, a]
            if mark_ab == 0 and mark_ba == 0:
                # No edge: the shared latent adds a (1, 1) edge.
                adj[a, b] = 1
                adj[b, a] = 1
            elif mark_ab == -1 and mark_ba == 1:
                # Tail at a becomes a circle.
                adj[a, b] = 2
            elif mark_ab == 1 and mark_ba == -1:
                # Tail at b becomes a circle.
                adj[b, a] = 2

    for pair in selection:
        a, b = pair[0], pair[1]
        mark_ab, mark_ba = adj[a, b], adj[b, a]
        if mark_ab == 0 and mark_ba == 0:
            # No edge: selection adds a (-1, -1) edge.
            adj[a, b] = -1
            adj[b, a] = -1
        elif mark_ab == -1 and mark_ba == 1:
            # Arrowhead at b becomes a circle.
            adj[b, a] = 2
        elif mark_ab == 1 and mark_ba == -1:
            # Arrowhead at a becomes a circle.
            adj[a, b] = 2
    return adj

def dic_to_adj(dic, num_nodes):
    """Turn a nested ``{node: {mark: neighbours}}`` dict into a dense float matrix.

    For each recognised mark string, ``out[node][neighbour]`` is set to its
    code: '<--' -> 1, '---' -> -1, 'o--' -> 2.  Other mark strings and empty
    neighbour collections are ignored.  Later marks overwrite earlier ones in
    dict iteration order.
    """
    mark_codes = {'<--': 1, '---': -1, 'o--': 2}
    out = np.zeros((num_nodes, num_nodes))
    for node, neighbours_by_mark in dic.items():
        for mark, neighbours in neighbours_by_mark.items():
            if len(neighbours) == 0 or mark not in mark_codes:
                continue
            code = mark_codes[mark]
            for other in neighbours:
                out[node][other] = code
    return out


# Experiment driver.  For each graph size in `variables`, run ICD on every
# available synthetic dataset listed in the matching row of `sample`.  Each
# result is compared against the PAG built from the ground-truth DAG, and
# the mean/variance of the metrics is printed.
interven = 'hard'
variables = [6, 7, 8]
sample = [[11, 20, 44, 47, 48, 49, 56, 82, 83, 87, 97, 104, 111, 112, 121, 132, 139, 142, 147, 151],
          [11, 13, 14, 35, 67, 68, 77, 87, 99, 121, 133, 161, 173, 186, 189, 192, 197, 208, 225, 230],
          [12, 52, 56, 72, 78, 91, 102, 117, 131, 138, 160, 162, 163, 176, 208, 214, 254, 261, 289, 344]]
# Alternative sample-id lists, kept for reference:
# sample = [[0, 16, 21, 22, 23, 35, 49, 91, 109, 112, 119, 144, 145, 161, 162, 164, 171, 175, 182, 193],
# [ 0, 9, 19, 22, 53, 56, 106, 111, 120, 154, 155, 160, 177, 208, 213, 231, 245, 265, 266, 268],
# [4, 9, 12, 34, 40, 47, 60, 66, 110, 112, 121, 142, 148, 149, 164, 177, 244, 250, 279, 282]]
for d, sample_ids in zip(variables, sample):
    acc_dag = []     # per-sample precision (reported as "accuracy" below)
    f1_dag = []
    recall_dag = []
    shd_dag = []
    for i in sample_ids:
        file_path = f'/home/gongxu.luo/L_S/L_S/latent_with_selection/General_G/{interven}/v_{d}/sample_{i}'
        if not os.path.exists(file_path):
            continue
        # allow_pickle=True can execute code embedded in the file, so only
        # load datasets produced by trusted experiment runs.  The `with`
        # block closes the .npz archive once the arrays have been read.
        with np.load(os.path.join(file_path, f'sample_{interven}.npz'), allow_pickle=True) as data:
            obs = data['obs']
            selection = data['selection']
            latent = data['latent']
            gt_adj = data['dag']
        PAG = dag_to_pag(gt_adj, latent, selection)
        print(i)
        nodes_set = set(range(obs.shape[1]))
        par_corr_test = CondIndepParCorr(dataset=obs, threshold=0.01)  # CI test with the given significance level
        icd = LearnStructICD(nodes_set, par_corr_test)  # instantiate an ICD learner
        icd.learn_structure()  # learn the causal graph
        # NOTE(review): reads the learner's private `_graph` dict; this may
        # break if the library changes its internals.
        predict = dic_to_adj(icd.graph._graph, obs.shape[1])

        ret_dire = count_dag_accuracy(PAG, predict)
        # Undefined metrics (None) are scored as 0 so they can be averaged.
        for metric in ('f1', 'recall', 'precision'):
            if ret_dire[metric] is None:
                ret_dire[metric] = 0
        f1_dag.append(ret_dire['f1'])
        recall_dag.append(ret_dire['recall'])
        acc_dag.append(ret_dire['precision'])
        shd_dag.append(ret_dire['shd'])

    if not shd_dag:
        # np.mean([]) is nan and emits a RuntimeWarning; say so explicitly.
        print(f'no samples found for v_{d}; skipping summary')
        continue
    print(f'the average accuracy of dag is {np.mean(acc_dag)}, variance is {np.var(acc_dag)}')
    print(f'the average accuracy of recall of dag is {np.mean(recall_dag)}, variance is {np.var(recall_dag)}')
    print(f'the average f1 score of dag is {np.mean(f1_dag)}, variance is {np.var(f1_dag)}')
    print(f'the average shd of dag is {np.mean(shd_dag)}, variance is {np.var(shd_dag)}')


