import numpy as np
from experiments import utils
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.colors import ListedColormap
from matplotlib.patches import Rectangle
import pandas as pd
import sklearn.metrics



def rc_approximation(method, X, W_est, C_true, trans_clos='FW', epsilon=0.1):
    r"""Approximate the root causes C from the data and score the estimate.

    Naive estimate of C for the two models

        X = (C + N)(I + \bar{W})    and    X = C(I + \bar{W}) + N,

    obtained by right-multiplying X with an inverse of (I + \bar{W}).

    Args:
        method: name of the algorithm that produced ``W_est``. Only
            'sparserc', 'notears' and 'golem' are evaluated.
        X: data matrix, multiplied on the right by a [d, d] matrix.
        W_est: [d, d] estimated matrix.
        C_true: ground-truth root causes, same shape as ``X @ W_est``.
        trans_clos: if 'FW', ``I - W_est`` is used as the inverse
            (presumably W_est is then the direct adjacency matrix and
            \bar{W} its transitive closure -- confirm with the caller);
            otherwise ``I + W_est`` is inverted numerically.
        epsilon: an estimated entry counts as non-zero when it exceeds
            ``epsilon * max(C_est)``.

    Returns:
        Tuple ``(nmse, support_tpr, support_fpr)``. All three are NaN when
        ``method`` is not one of the evaluated algorithms.

    Side effects:
        Calls ``visualize_spectra`` (which saves a figure) and prints the
        TP / P / FP / N counts.
    """
    if method not in ('sparserc', 'notears', 'golem'):
        # only top-performing algorithms are evaluated
        nan = float("nan")
        return nan, nan, nan

    d = W_est.shape[0]
    identity = np.eye(d)
    if trans_clos == 'FW':
        inverse_refl_trans_clos = identity - W_est
    else:
        inverse_refl_trans_clos = np.linalg.inv(identity + W_est)

    C_est = X @ inverse_refl_trans_clos
    spectral_nmse = np.linalg.norm(C_est - C_true) / np.linalg.norm(C_true)

    # binary support masks of the estimate and of the ground truth
    non_zero_est = np.where(C_est > epsilon * np.max(C_est), 1, 0)
    non_zero_true = np.where(C_true > 0, 1, 0)
    zero_true = np.where(C_true > 0, 0, 1)

    # confusion counts, computed once and reused for the rates and the log line
    tp = np.sum(non_zero_est * non_zero_true)
    p = np.sum(non_zero_true)
    fp = np.sum(non_zero_est * zero_true)
    n = np.sum(zero_true)

    spectral_support_tpr = tp / p
    spectral_support_fpr = fp / n

    visualize_spectra(C_true, non_zero_true, C_est, non_zero_est, method)

    print("TP = {}, P = {}, FP = {}, N = {}".format(tp, p, fp, n))

    return spectral_nmse, spectral_support_tpr, spectral_support_fpr


def rct_approximation(method, X, T, W_est, C_true=None, epsilon=0.1):
    """Approximate time-lagged root causes and optionally score them.

    The estimate is C ~ X - XA, where A is the block Toeplitz matrix implied
    by the lag blocks of ``W_est``. It is evaluated step by step as

        c[t] = x[t] - x[t] B_0 - x[t-1] B_1 - ... - x[t-p] B_p

    (lags reaching before t = 0 are dropped).

    Args:
        method: unused; kept so existing callers keep working.
        X: data in one of three layouts:
            * 3-D ``(n, T, d)``: flattened to ``(n, d * T)`` and the result
              is reshaped back to ``(n, T, d)``;
            * 2-D with ``d * T`` columns: used as is;
            * 2-D with another column count (expected ``(a, d)``): cut into
              ``a // T`` consecutive windows of ``T`` rows; trailing rows
              that do not fill a window are discarded.
        T: number of time steps per sample.
        W_est: ``[d, d * (p + 1)]`` matrix holding ``[B_0 | B_1 | ... | B_p]``.
        C_true: ground-truth root causes with the same shape as the returned
            estimate, or None when unknown (e.g. real datasets).
        epsilon: an estimated entry counts as a root cause when it exceeds
            this absolute threshold.

    Returns:
        ``(C_est, c_nmse, c_total, c_shd)`` where ``c_nmse`` is the normalised
        error, ``c_total`` the number of positive entries of ``C_true`` and
        ``c_shd`` the number of entries whose support disagrees. The three
        metrics are NaN when ``C_true`` is None.
    """
    d = W_est.shape[0]
    p = W_est.shape[1] // d - 1  # number of lags

    restore_3d = X.ndim > 2
    if restore_3d:
        n_samples = X.shape[0]
        X = X.reshape((n_samples, d * T))

    if X.shape[1] != d * T:
        # one long series: split it into complete windows of T steps
        n_windows = X.shape[0] // T
        X = X[:n_windows * T, :].reshape((n_windows, d * T))

    # Stack the lag blocks from oldest to newest:
    # B = [B_p; B_{p-1}; ...; B_1; B_0]
    lag_blocks = [W_est[:, i * d:(i + 1) * d] for i in range(p + 1)]
    B = np.concatenate(lag_blocks[::-1], axis=0)

    C_est = np.zeros(X.shape)
    # Every one of the T steps is processed, including the last one.
    for t in range(T):
        start = max(t - p, 0)
        history = X[:, start * d:(t + 1) * d]  # [x(t-p) ... x(t)], truncated at 0
        # The trailing rows of B are the blocks matching the available history.
        Y = history @ B[-history.shape[1]:, :]
        C_est[:, t * d:(t + 1) * d] = X[:, t * d:(t + 1) * d] - Y

    if restore_3d:
        C_est = C_est.reshape(n_samples, T, d)

    if C_true is None:
        # in real datasets the root cause ground truth is unknown
        nan = float("nan")
        return C_est, nan, nan, nan

    # evaluating the estimation
    c_nmse = np.linalg.norm(C_est - C_true) / np.linalg.norm(C_true)

    ones_est = np.where(C_est > epsilon, 1, 0)
    ones_true = np.where(C_true > 0, 1, 0)
    # total disagreement between the two supports
    c_shd = (ones_est * (1 - ones_true) + (1 - ones_est) * ones_true).sum()
    c_total = ones_true.sum()  # total root causes

    return C_est, c_nmse, c_total, c_shd



def visualize_spectra(C_true, non_zero_true, C_est, non_zero_est, method):
    """Plot true vs. estimated root-cause support side by side and save a PDF.

    The left panel shows ``non_zero_true``. The right panel overlays the
    estimate on the ground truth using four colours:

        white - zero in both
        grey  - true root cause that was missed
        red   - estimated root cause that is not in the ground truth
        black - correctly recovered root cause

    Args:
        C_true: unused; kept so existing callers keep working.
        non_zero_true: 0/1 array marking the true support.
        C_est: 2-D estimate; only its shape is used, to size the frame.
        non_zero_est: array whose non-zero entries mark the estimated support.
        method: algorithm name, used in the output file name.

    Side effects:
        Writes ``plots/root_causes_<method>.pdf`` (creating ``plots/`` if
        needed) and closes the figure afterwards.
    """
    import os  # local import: only needed to create the output directory

    n_rows, n_cols = C_est.shape

    with plt.style.context('ggplot'):
        plt.rcParams['font.family'] = 'gillsans'
        plt.rcParams['xtick.color'] = 'black'
        plt.rcParams['ytick.color'] = 'black'

        # Four-entry colour table; with vmin=0 / vmax=1 the values
        # 0, 0.33, 0.66 and 1 fall into entries 0, 1, 2 and 3.
        custom_cmp = ListedColormap(np.array([
            [1.0, 1.0, 1.0, 1.0],  # white
            [0.5, 0.5, 0.5, 1.0],  # grey
            [1.0, 0.0, 0.0, 1.0],  # red
            [0.0, 0.0, 0.0, 1.0],  # black
        ]))

        est_support = np.where(non_zero_est != 0, 1, 0)
        common = non_zero_true * est_support
        wrong = est_support - common
        missed = non_zero_true - common
        overlay = common + 0.66 * wrong + 0.33 * missed

        fig, (ax1, ax2) = plt.subplots(1, 2)
        try:
            panels = (
                (ax1, non_zero_true, 'Ground Truth'),
                (ax2, overlay, 'Estimated'),
            )
            for ax, image, title in panels:
                # Fixed limits keep the colour meaning stable even when a
                # category is absent from the image.
                ax.imshow(image, cmap=custom_cmp, vmin=0, vmax=1)
                ax.grid(False)
                # imshow puts columns on x and rows on y, so the frame is
                # n_cols wide and n_rows tall.
                ax.add_patch(Rectangle((-0.5, -0.5), n_cols - 0.15, n_rows - 0.15,
                                       linewidth=1, edgecolor='black', facecolor='none'))
                ax.axis('off')
                ax.set_title(title)

            fig.suptitle('Root causes')

            os.makedirs('plots', exist_ok=True)
            fig.savefig('plots/root_causes_{}.pdf'.format(method), dpi=1000)
        finally:
            # release the figure so repeated calls do not accumulate open figures
            plt.close(fig)


def count_accuracy(B_true, B_est):
    """Score an estimated graph against the ground-truth graph.

    Edge categories:
        true positive  - predicted edge exists in the truth, same direction
        reverse        - predicted edge exists in the truth, opposite direction
        false positive - predicted edge is absent from the true skeleton

    An undirected prediction (-1) is treated favourably: it is a true
    positive if the true graph has that edge in either direction.

    Args:
        B_true (np.ndarray): [d, d] ground truth graph, {0, 1}
        B_est (np.ndarray): [d, d] estimate, {0, 1, -1}; -1 marks an
            undirected edge of a CPDAG and must appear only once per pair.
            The estimate is not required to be acyclic.

    Returns:
        dict with keys
            fdr: (reverse + false positive) / prediction positive
            tpr: (true positive) / condition positive
            fpr: (reverse + false positive) / condition negative
            shd: undirected extra + undirected missing + reverse
            nnz: prediction positive

    Raises:
        ValueError: if B_est contains values outside the allowed set, or an
            undirected edge is recorded in both (i, j) and (j, i).
    """
    undirected_mask = B_est == -1
    if undirected_mask.any():  # cpdag
        if not np.isin(B_est, (0, 1, -1)).all():
            raise ValueError('B_est should take value in {0,1,-1}')
        if (undirected_mask & (B_est.T == -1)).any():
            raise ValueError('undirected edge should only appear once')
    elif not np.isin(B_est, (0, 1)).all():  # dag
        raise ValueError('B_est should take value in {0,1}')

    n_nodes = B_true.shape[0]

    # flat (linear) indices of the relevant entries
    est_undirected = np.flatnonzero(undirected_mask)
    est_directed = np.flatnonzero(B_est == 1)
    true_edges = np.flatnonzero(B_true)
    true_edges_flipped = np.flatnonzero(B_true.T)
    true_skeleton = np.concatenate([true_edges, true_edges_flipped])

    # true positives: directed hits plus undirected edges found in the skeleton
    directed_hits = np.intersect1d(est_directed, true_edges, assume_unique=True)
    undirected_hits = np.intersect1d(est_undirected, true_skeleton, assume_unique=False)
    n_true_pos = len(directed_hits) + len(undirected_hits)

    # false positives: predictions outside the true skeleton
    directed_misses = np.setdiff1d(est_directed, true_skeleton, assume_unique=False)
    undirected_misses = np.setdiff1d(est_undirected, true_skeleton, assume_unique=False)
    n_false_pos = len(directed_misses) + len(undirected_misses)

    # reversed edges: directed predictions that only match the flipped truth
    not_in_truth = np.setdiff1d(est_directed, true_edges, assume_unique=True)
    n_reversed = len(np.intersect1d(not_in_truth, true_edges_flipped, assume_unique=True))

    # ratios; denominators are clamped to 1 to avoid dividing by zero
    n_predicted = len(est_directed) + len(est_undirected)
    n_cond_negative = 0.5 * n_nodes * (n_nodes - 1) - len(true_edges)
    n_wrong = n_reversed + n_false_pos
    fdr = float(n_wrong) / max(n_predicted, 1)
    tpr = float(n_true_pos) / max(len(true_edges), 1)
    fpr = float(n_wrong) / max(n_cond_negative, 1)

    # structural hamming distance on the lower-triangular skeletons
    est_skeleton_lower = np.flatnonzero(np.tril(B_est + B_est.T))
    true_skeleton_lower = np.flatnonzero(np.tril(B_true + B_true.T))
    n_extra = len(np.setdiff1d(est_skeleton_lower, true_skeleton_lower, assume_unique=True))
    n_missing = len(np.setdiff1d(true_skeleton_lower, est_skeleton_lower, assume_unique=True))
    shd = n_extra + n_missing + n_reversed

    return {'fdr': fdr, 'tpr': tpr, 'fpr': fpr, 'shd': shd, 'nnz': n_predicted}