import igraph as ig
import numpy as np
import time,datetime,os
from rich import print as rprint
import itertools

def is_dag(W):
    """Report whether the adjacency matrix ``W`` is acyclic, using igraph.

    Args:
        W (np.ndarray): square adjacency matrix; it is converted with
            ``tolist()`` and handed to ``igraph.Graph.Weighted_Adjacency``.

    Returns:
        bool: the result of igraph's ``Graph.is_dag()`` on that graph.
    """
    adjacency_rows = W.tolist()
    graph = ig.Graph.Weighted_Adjacency(adjacency_rows)
    acyclic = graph.is_dag()
    return acyclic


def count_accuracy(B_true, B_est):
    """Compute various accuracy metrics for B_est.

    true positive = predicted association exists in condition in correct direction
    reverse = predicted association exists in condition in opposite direction
    false positive = predicted association does not exist in condition

    Undirected (CPDAG, ``-1``) predictions are treated favorably: one counts as
    a true positive whenever the true graph has an edge between the pair in
    either direction.

    Args:
        B_true (np.ndarray): [d, d] ground truth graph, {0, 1}
        B_est (np.ndarray): [d, d] estimate, {0, 1, -1}, -1 is undirected edge in CPDAG

    Returns:
        dict with keys
            fdr: (reverse + false positive) / prediction positive
            tpr: (true positive) / condition positive
            fpr: (reverse + false positive) / condition negative
            precision: (true positive) / prediction positive
            recall: (true positive) / condition positive (equal to tpr)
            f1: 2 * precision * recall / (precision + recall), computed from
                the already-rounded precision and recall
            extra: skeleton edges in the estimate absent from the true skeleton
            missing: skeleton edges in the truth absent from the estimate
            reverse: directed predictions whose true edge points the other way
            shd: extra + missing + reverse
            nnz: prediction positive (directed + undirected predictions)
        All ratios are rounded to 4 decimals.

    Raises:
        ValueError: if B_est holds values outside {0, 1, -1} (CPDAG) or
            {0, 1} (DAG), or if an undirected edge is stored as -1 in both
            (i, j) and (j, i).
    """
    if (B_est == -1).any():  # cpdag
        if not ((B_est == 0) | (B_est == 1) | (B_est == -1)).all():
            raise ValueError('B_est should take value in {0,1,-1}')
        if ((B_est == -1) & (B_est.T == -1)).any():
            raise ValueError('undirected edge should only appear once')
    else:  # dag
        if not ((B_est == 0) | (B_est == 1)).all():
            raise ValueError('B_est should take value in {0,1}')
        if not is_dag(B_est):
            # Only warn (not raise) so that cyclic estimates can still be scored.
            print('B_est should be a DAG')
    d = B_true.shape[0]
    # linear index of nonzeros
    pred_und = np.flatnonzero(B_est == -1)
    pred = np.flatnonzero(B_est == 1)
    cond = np.flatnonzero(B_true)
    cond_reversed = np.flatnonzero(B_true.T)
    # Both orientations of every true edge, i.e. the true skeleton.
    cond_skeleton = np.concatenate([cond, cond_reversed])
    # true pos
    true_pos = np.intersect1d(pred, cond, assume_unique=True)
    # treat undirected edge favorably
    true_pos_und = np.intersect1d(pred_und, cond_skeleton, assume_unique=True)
    true_pos = np.concatenate([true_pos, true_pos_und])
    # false pos: predictions with no true edge in either direction
    false_pos = np.setdiff1d(pred, cond_skeleton, assume_unique=True)
    false_pos_und = np.setdiff1d(pred_und, cond_skeleton, assume_unique=True)
    false_pos = np.concatenate([false_pos, false_pos_und])
    # reverse: directed predictions that are wrong but whose transpose is true
    wrong_direction = np.setdiff1d(pred, cond, assume_unique=True)
    reverse_idx = np.intersect1d(wrong_direction, cond_reversed, assume_unique=True)
    # compute ratio
    pred_size = len(pred) + len(pred_und)
    cond_neg_size = 0.5 * d * (d - 1) - len(cond)
    fdr = round(float(len(reverse_idx) + len(false_pos)) / max(pred_size, 1), 4)
    tpr = round(float(len(true_pos)) / max(len(cond), 1), 4)
    fpr = round(float(len(reverse_idx) + len(false_pos)) /
                max(cond_neg_size, 1), 4)
    # structural hamming distance, computed on the lower triangle of the
    # symmetrised matrices so each unordered pair is counted once
    pred_lower = np.flatnonzero(np.tril(B_est + B_est.T))
    cond_lower = np.flatnonzero(np.tril(B_true + B_true.T))
    extra_lower = np.setdiff1d(pred_lower, cond_lower, assume_unique=True)
    missing_lower = np.setdiff1d(cond_lower, pred_lower, assume_unique=True)
    extra = len(extra_lower)
    missing = len(missing_lower)
    reverse = len(reverse_idx)
    shd = extra + missing + reverse
    # true_pos includes correctly placed undirected edges, so the denominator
    # must include undirected predictions too (same as fdr); using len(pred)
    # alone lets precision exceed 1 on CPDAG input.
    precision = round(float(len(true_pos)) / max(pred_size, 1), 4)
    recall = round(float(len(true_pos)) / max(len(cond), 1), 4)
    f1 = round(2 * precision * recall / max(precision + recall, 1e-8), 4)
    return {'fdr': fdr, 'tpr': tpr, 'fpr': fpr, "precision": precision, "recall": recall, "f1": f1,
            'extra': extra, 'missing': missing,  'reverse': reverse, 'shd': shd, 'nnz': pred_size}

def numerical_SHD(B_true, B_est):
    """Weighted (numerical) distance between true and estimated weight matrices.

    Args:
        B_true (np.ndarray): ground-truth weighted adjacency matrix.
        B_est (np.ndarray): estimated weighted adjacency matrix, same shape
            as ``B_true``. Array-likes (e.g. nested lists) are also accepted.

    Returns:
        dict with
            numerical_SHD_noextra: sum of |B_true - B_est| restricted to the
                entries where B_true is nonzero (weights the estimate puts on
                absent true edges are ignored), rounded to 4 decimals.
            numerical_SHD: sum of |B_true - B_est| over all entries, rounded
                to 4 decimals.
    """
    B_true = np.asarray(B_true)
    B_est = np.asarray(B_est)
    # Zero out the estimate wherever there is no true edge so that extra
    # predicted edges contribute nothing to the "noextra" distance.
    est_on_true_support = np.where(B_true == 0, 0, B_est)
    noextra = round(np.abs(B_true - est_on_true_support).sum(), 4)
    total = round(np.abs(B_true - B_est).sum(), 4)
    return {'numerical_SHD_noextra': noextra, 'numerical_SHD': total}

def sid(tar,pred):
    """Structural Intervention Distance between ``tar`` and ``pred``, best effort.

    Uses the optional ``cdt`` package (``cdt.metrics.SID``). If ``cdt``
    cannot be imported or the computation raises, ``{'SID': None}`` is
    returned so the surrounding evaluation can continue.

    Args:
        tar: ground-truth graph passed straight to ``cdt.metrics.SID``.
        pred: estimated graph passed straight to ``cdt.metrics.SID``.

    Returns:
        dict: ``{'SID': value}``, where value is a Python scalar or None.
    """
    try:
        from cdt.metrics import SID
        return {'SID': SID(tar, pred).item()}
    except Exception:
        # Optional dependency / backend may be missing; degrade gracefully,
        # but let KeyboardInterrupt and SystemExit propagate.
        return {'SID': None}
    
def evaluation(model,true_dag,weight_true_dag,time1,time2,lambda1,lambda2,sigma,args,output_path):
    """Score a fitted model, append one result row to a CSV, and save its weights.

    The metric dict combines ``count_accuracy`` on the binary graph,
    ``numerical_SHD`` on the weighted graph, ``sid``, the run time, the
    hyper-parameters, a finish timestamp, and the model's optional
    ``precision``/``recall``/``f1`` attributes (0 when the model lacks them).
    The dict is printed with ``rprint``.

    One row (run parameters taken from ``args``, followed by the metric
    values) is appended to ``output_path``. A header row is written first
    when the file does not exist yet. The estimated weight matrix is saved
    to ``out/W_est/<label>.csv``, where ``label`` is the parameter values
    joined with ``'|'``.

    Args:
        model: fitted model exposing ``causal_matrix`` and
            ``weight_causal_matrix``.
        true_dag: ground-truth binary adjacency matrix.
        weight_true_dag: ground-truth weighted adjacency matrix.
        time1, time2: start/end timestamps; their difference is the run time.
        lambda1, lambda2, sigma: hyper-parameters recorded in the row.
        args: run configuration namespace (the attributes read are listed in
            ``parameter`` below).
        output_path: path of the results CSV to append to.
    """
    import csv

    metric = count_accuracy(true_dag, model.causal_matrix)
    metric.update(numerical_SHD(weight_true_dag, model.weight_causal_matrix))
    metric.update(sid(true_dag, model.causal_matrix))
    metric['time'] = round(time2 - time1, 4)
    metric['lambda1'] = round(lambda1, 4)
    metric['lambda2'] = round(lambda2, 4)
    metric['sigma'] = round(sigma, 4)
    metric['finished'] = datetime.datetime.now()
    # Prior-quality metrics exist only on some models; record 0 for all three
    # when any of them is missing.
    try:
        metric['p_p'] = model.precision
        metric['p_r'] = model.recall
        metric['p_f1'] = model.f1
    except AttributeError:
        metric['p_p'] = 0
        metric['p_r'] = 0
        metric['p_f1'] = 0
    rprint(metric)

    parameter={'n_nodes':args.n_nodes,'ER':args.ER,'size':args.size, 'graph_type':args.graph_type,'random':args.random, 'method':args.method,'sem_type':args.sem_type,'loss_type':args.loss_type,
               'prior_type':args.prior_type,'proportion':args.proportion,'confidence':args.confidence,'error_prior_proportion':args.error_prior_proportion,'error_prior_type':args.error_prior_type,'alg':args.alg,'adaptive_degree':args.adaptive_degree,'scale':args.scale}
    parameter_values = [str(v) for v in parameter.values()]
    metric_values = [str(v) for v in metric.values()]

    # Create the results directory if needed instead of failing on open().
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    write_header = not os.path.exists(output_path)
    # csv.writer quotes any value that contains a comma, which a plain
    # ','.join would silently split into extra columns.
    with open(output_path, 'a', newline='') as f:
        writer = csv.writer(f, lineterminator='\n')
        if write_header:
            writer.writerow(list(parameter.keys()) + list(metric.keys()))
        writer.writerow(parameter_values + metric_values)

    label = '|'.join(parameter_values)
    os.makedirs(os.path.join('out', 'W_est'), exist_ok=True)
    np.savetxt(f'out/W_est/{label}.csv', model.weight_causal_matrix, delimiter=',', fmt='%.4f')