import numpy as np
def _orient_on_skeleton(dags, dag_s1):
    """Restrict a batch of square matrices to the skeleton of ``dag_s1`` and fill gaps.

    Vectorized equivalent of a per-pair double loop over ``i > j``:

    * entries ``[p, q]`` where ``dag_s1`` is zero in both directions are zeroed;
    * for each pair that is zero in both directions after masking, an entry is
      set to 1 based on ``dag_s1``:
        - ``dag_s1[p, q] == 1 and dag_s1[q, p] == 0`` -> set ``[p, q]``;
        - ``dag_s1[p, q] == 1 and dag_s1[q, p] == 1`` -> set ``[i, j]`` with
          ``i > j`` (strict lower triangle).

    Args:
        dags: array of shape (B, n, n).
        dag_s1: array of shape (n, n) whose skeleton is kept.

    Returns:
        A new array of shape (B, n, n); the inputs are not modified.
    """
    n = dag_s1.shape[0]
    skeleton = np.logical_or(dag_s1, dag_s1.T)
    # The product allocates a fresh array, so filling it in place below never
    # writes into the caller's data.
    masked = dags * skeleton
    # A pair is "unset" when neither direction survived the mask; this test is
    # symmetric in (p, q), so it applies equally to both triangles.
    unset = (masked == 0) & (np.swapaxes(masked, -1, -2) == 0)
    is_one = dag_s1 == 1
    # One-way reference edge: copy that direction. Cannot fire on the diagonal,
    # because dag_s1[p, p] cannot equal both 1 and 0.
    one_way = is_one & (dag_s1.T == 0)
    # Two-way reference edge: keep only the entry below the diagonal (i > j).
    two_way = is_one & (dag_s1.T == 1) & np.tri(n, k=-1, dtype=bool)
    masked[unset & (one_way | two_way)] = 1
    return masked


def keep_skeleton(dag_batch, aug_dag):
    """Project every matrix of a batch onto the skeleton of ``aug_dag``.

    The last row and column of ``aug_dag`` and of each matrix in ``dag_batch``
    are dropped first (NOTE(review): presumably an augmented extra node —
    confirm against callers). On the remaining (D-1, D-1) blocks:

    * entries at pairs where ``aug_dag`` is zero in both directions are zeroed;
    * pairs that are adjacent in ``aug_dag`` but zero in both directions of the
      masked matrix get one entry set to 1. If exactly ``aug_dag[p, q] == 1`` and
      ``aug_dag[q, p] == 0``, entry ``[p, q]`` is set. If both directions equal 1,
      entry ``[i, j]`` with ``i > j`` is set.

    Args:
        dag_batch: array-like of shape (B, D, D).
        aug_dag: array-like of shape (D, D).

    Returns:
        ``(masked, last_col)`` where ``masked`` is a float64 array of shape
        (B, D-1, D-1). ``last_col`` is ``aug_dag[:-1, -1]``, which is a view into
        ``aug_dag`` when ``aug_dag`` is already an ndarray.
    """
    dag_batch = np.asarray(dag_batch)
    aug_dag = np.asarray(aug_dag)
    oriented = _orient_on_skeleton(dag_batch[:, :-1, :-1], aug_dag[:-1, :-1])
    dag_masked_batch = oriented.astype(np.float64)
    return dag_masked_batch, aug_dag[0:-1, -1]


def keep_skeleton_with_index(dag_batch, aug_dag):
    """Like :func:`keep_skeleton`, but keep the full (D, D) size of each matrix.

    The output has the shape of ``dag_batch`` and is laid out as follows:

    * the top-left (D-1, D-1) block holds the skeleton-projected matrix,
      computed exactly as in ``keep_skeleton``;
    * the last column (rows 0..D-2) holds ``aug_dag[:-1, -1]``, identical for
      every batch element;
    * the last row is all zeros.

    Args:
        dag_batch: array-like of shape (B, D, D).
        aug_dag: array-like of shape (D, D).

    Returns:
        ``(masked, last_col)`` where ``masked`` is a float64 array of shape
        (B, D, D). ``last_col`` is ``aug_dag[:-1, -1]``, which is a view into
        ``aug_dag`` when ``aug_dag`` is already an ndarray.
    """
    dag_batch = np.asarray(dag_batch)
    aug_dag = np.asarray(aug_dag)
    dag_masked_batch = np.zeros(dag_batch.shape[:3])
    dag_masked_batch[:, :-1, :-1] = _orient_on_skeleton(
        dag_batch[:, :-1, :-1], aug_dag[:-1, :-1]
    )
    # The same column is written for every batch element; broadcasting replaces
    # the former per-element assignment.
    dag_masked_batch[:, :-1, -1] = aug_dag[:-1, -1]
    return dag_masked_batch, aug_dag[0:-1, -1]