"""
This script loads the 4 nodes MOSFET dataset and runs the CD-NOD
algorithm and reports accuracy metrics.
"""

from causallearn.search.ConstraintBased.CDNOD import cdnod
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
import time


def plot_directed_graph(adjacency_matrix):
    """Draw the directed graph encoded by ``adjacency_matrix`` and save it as a PDF.

    A truthy entry ``adjacency_matrix[child, parent]`` is drawn as the arrow
    ``X<parent+1> -> X<child+1>``. Node positions are hand-tuned for nodes
    X1..X6. The figure is saved into the current working directory and then
    shown without blocking.
    """
    node_count = adjacency_matrix.shape[0]
    labels = ['X' + str(k) for k in range(1, node_count + 1)]

    digraph = nx.DiGraph()
    digraph.add_nodes_from(labels)
    for child in range(node_count):
        for parent in range(node_count):
            if adjacency_matrix[child, parent]:
                digraph.add_edge(labels[parent], labels[child])

    # Fixed, hand-picked coordinates so repeated figures share one layout.
    layout = {
        'X1': (-1, 0.3),
        'X2': (0, 0),
        'X3': (1.7, 0.42),
        'X4': (4.5, 0.455),
        'X5': (5, -0.06),
        'X6': (8, 0.13),
    }

    nx.draw_networkx(
        digraph,
        layout,
        font_size=26,
        node_size=3000,
        node_color="white",
        edgecolors="black",
        linewidths=3,
        width=2,
    )

    axes = plt.gca()
    axes.margins(0.20)  # pad the axes so the large node circles are not clipped
    plt.axis("off")
    plt.gcf().savefig("Network Reconstruction Random Variable independence model" + ".pdf")
    plt.show(block=False)
    plt.pause(0.001)
   

    
# Convert a causal-learn graph into a 0/1 adjacency matrix, following the
# edge-endpoint encoding described in the causal-learn documentation.
def retrieve_adj(cg):
    """Return a 0/1 adjacency matrix for the graph stored in ``cg.G.graph``.

    The result uses the convention ``Adj[child, parent] = 1`` for an edge
    ``parent -> child``. For a pair of nodes (i, j) the endpoint codes of
    ``cg.G.graph`` are read as:

    * ``graph[j, i] == 1 and graph[i, j] == -1``: directed edge i -> j,
      so ``Adj[j, i] = 1``.
    * ``graph[j, i] == graph[i, j] == -1``: undirected edge i -- j,
      so both ``Adj[i, j]`` and ``Adj[j, i]`` are set.
    * ``graph[j, i] == graph[i, j] == 1``: bidirected edge i <-> j,
      so both ``Adj[i, j]`` and ``Adj[j, i]`` are set.

    Every other code pair leaves the entry at 0. The diagonal is always 0;
    no self-loops are produced.

    Parameters
    ----------
    cg : object
        Any object exposing a square ``cg.G.graph`` array of endpoint codes.

    Returns
    -------
    numpy.ndarray
        Float array of shape (n, n) holding 0.0 / 1.0.
    """
    graph = np.asarray(cg.G.graph)
    # Each mask is indexed [j, i] and compares graph[j, i] with graph[i, j].
    directed = (graph == 1) & (graph.T == -1)
    undirected = (graph == -1) & (graph.T == -1)
    bidirected = (graph == 1) & (graph.T == 1)
    Adj = (directed | undirected | bidirected).astype(float)
    # An edge between distinct nodes must never be recorded as a self-loop.
    np.fill_diagonal(Adj, 0)
    return Adj

def find_moral_graph_and_skeleton_from_adjacency_matrix(Adj):
    """Return the moral graph and the skeleton of a directed graph.

    Parameters
    ----------
    Adj : numpy.ndarray of shape (n, n)
        Adjacency matrix of the directed graph, with ``Adj[child, parent] == 1``
        for an edge ``parent -> child``.

    Returns
    -------
    MG : numpy.ndarray of shape (n, n)
        Symmetric adjacency matrix of the moral graph. It is the skeleton plus
        an edge between every pair of distinct nodes that share a child.
    Skeleton : numpy.ndarray of shape (n, n)
        Symmetric 0/1 adjacency matrix of the underlying undirected graph.
    """
    Adj = np.asarray(Adj)
    # Clip to 0/1 so a 2-cycle (Adj[i, j] == Adj[j, i] == 1) still reads as a
    # single adjacency instead of an entry of 2.
    Skeleton = ((Adj != 0) | (Adj.T != 0)).astype(Adj.dtype)

    is_edge = (Adj == 1).astype(int)
    # (is_edge.T @ is_edge)[p, q] counts the children shared by nodes p and q.
    shares_child = (is_edge.T @ is_edge) > 0
    np.fill_diagonal(shares_child, False)

    MG = Skeleton.copy()
    MG[shares_child] = 1  # "marry" co-parents
    return MG, Skeleton

def find_true_Essential_graph(Skeleton, Adj):
    """Return the essential graph (CPDAG) of the DAG given by ``Adj``.

    Matrix convention, shared with ``Adj``: a nonzero ``[child, parent]`` entry
    means ``parent -> child``. An undirected edge ``a -- b`` has both
    ``[a, b]`` and ``[b, a]`` nonzero.

    Algorithm:
      1. Start from the skeleton and orient every v-structure
         ``p -> c <- q``, where p and q are non-adjacent parents of c in ``Adj``.
      2. Repeatedly apply Meek's orientation rules R1-R3 until no undirected
         edge changes.

    Parameters
    ----------
    Skeleton : numpy.ndarray of shape (n, n)
        Symmetric adjacency matrix of the DAG's skeleton.
    Adj : numpy.ndarray of shape (n, n)
        DAG adjacency matrix with ``Adj[child, parent] == 1``.

    Returns
    -------
    numpy.ndarray of shape (n, n)
        A copy of ``Skeleton`` in which every compelled edge is oriented.
    """
    n = Skeleton.shape[0]
    graph = np.copy(Skeleton)

    # Step 1: orient v-structures p -> child <- q with p, q non-adjacent.
    for child in range(n):
        parents = [p for p in range(n) if p != child and Adj[child, p] == 1]
        for k, p in enumerate(parents):
            for q in parents[k + 1:]:
                if not (Skeleton[p, q] or Skeleton[q, p]):
                    graph[p, child] = 0
                    graph[q, child] = 0

    def is_adjacent(a, b):
        return bool(graph[a, b]) or bool(graph[b, a])

    def is_undirected(a, b):
        return bool(graph[a, b]) and bool(graph[b, a])

    def is_directed(a, b):
        # True for an oriented edge a -> b.
        return bool(graph[b, a]) and not graph[a, b]

    def compelled(a, b):
        # Whether Meek rules R1-R3 force the undirected edge a -- b to a -> b.
        others = [c for c in range(n) if c != a and c != b]
        # R1: c -> a -- b with c and b non-adjacent.
        if any(is_directed(c, a) and not is_adjacent(c, b) for c in others):
            return True
        # R2: a -> c -> b (orienting b -> a would create a cycle).
        if any(is_directed(a, c) and is_directed(c, b) for c in others):
            return True
        # R3: a -- c1 -> b and a -- c2 -> b with c1, c2 non-adjacent.
        mids = [c for c in others if is_undirected(a, c) and is_directed(c, b)]
        return any(
            not is_adjacent(c1, c2)
            for i, c1 in enumerate(mids)
            for c2 in mids[i + 1:]
        )

    # Step 2: propagate orientations until a fixed point is reached.
    changed = True
    while changed:
        changed = False
        for a in range(n):
            for b in range(n):
                if a != b and is_undirected(a, b) and compelled(a, b):
                    graph[a, b] = 0  # keep graph[b, a] -> edge a -> b
                    changed = True

    return graph

def get_confusion_matrix(true_graph, estimated_graph):
    """Count edge-level confusion-matrix entries between two adjacency matrices.

    Every off-diagonal position (i, j) of ``true_graph`` is compared with the
    same position of ``estimated_graph``. A value of 1 counts as "edge" and a
    value of 0 as "no edge". Positions where either matrix holds any other
    value are skipped.

    Returns
    -------
    tuple
        ``(TP, FP, TN, FN)`` as ints.
    """
    size = true_graph.shape[0]
    tp = fp = tn = fn = 0
    off_diagonal = ((r, c) for r in range(size) for c in range(size) if r != c)
    for r, c in off_diagonal:
        truth = true_graph[r, c]
        guess = estimated_graph[r, c]
        if truth == 1:
            if guess == 1:
                tp += 1
            elif guess == 0:
                fn += 1
        elif truth == 0:
            if guess == 1:
                fp += 1
            elif guess == 0:
                tn += 1
    return tp, fp, tn, fn

# ---------------------------------------------------------------------------
# Experiment: sweep the CD-NOD significance level and score each estimate
# against the known ground-truth essential graph.
# ---------------------------------------------------------------------------
nNodes = 4
# Module-level alias kept because functions in this file may read it as a global.
num_nodes = nNodes
data = np.loadtxt('FourNodesMOSFETHardwareData1M.txt', dtype='d', delimiter=',')

# Ground-truth DAG in the Adj[child, parent] = 1 convention:
# X1 -> X3, X2 -> X3, X3 -> X4.
Adj = np.array([[0, 0, 0, 0],
                [0, 0, 0, 0],
                [1, 1, 0, 0],
                [0, 0, 1, 0]])

MG, Skeleton = find_moral_graph_and_skeleton_from_adjacency_matrix(Adj)

true_ess_graph = find_true_Essential_graph(Skeleton, Adj)

# Context index passed to CD-NOD: one value (the row number) per sample.
t_indx = np.arange(0, data.shape[0])
c_indx = t_indx.reshape(-1, 1)


def _safe_div(numerator, denominator):
    """Return numerator / denominator, or 0.0 when the denominator is zero."""
    return numerator / denominator if denominator else 0.0


# Start below any attainable score so alpha_best is always a tested value.
CS_prev = -np.inf
alpha_best = 0
# Same grid as np.arange(0.01, 0.8, 0.01), rounded to remove float noise
# such as 0.060000000000000005 in the printed alphas.
alpha_vec = np.round(np.arange(0.01, 0.8, 0.01), 2)
for alpha_val in alpha_vec:
    start_time = time.perf_counter()
    cg = cdnod(data, c_indx, alpha_val, indep_test='fisherz')
    end_time = time.perf_counter()
    est_essential_graph_freq = retrieve_adj(cg)
    # Drop the last row/column: the extra node CD-NOD adds for c_indx
    # (assumed to be appended last -- confirm against the causal-learn docs).
    est_essential_graph_freq = est_essential_graph_freq[:-1, :-1]

    [TP, FP, TN, FN] = get_confusion_matrix(true_ess_graph, est_essential_graph_freq)
    # Guard every ratio: an empty or saturated estimate would otherwise divide by zero.
    TPR = _safe_div(TP, TP + FN)
    FPR = _safe_div(FP, TN + FP)
    CS = TPR - FPR  # Youden's J statistic, used to pick the best alpha
    computation_time = end_time - start_time
    print("----------------------------------------______")
    print(f"alpha = {alpha_val}")
    print(f"TPR = {TPR}, FPR = {FPR}")
    print(f"Elapsed algorithm run time = {computation_time}")
    precision = _safe_div(TP, TP + FP)
    recall = _safe_div(TP, TP + FN)
    f1 = _safe_div(2 * precision * recall, precision + recall)
    print('F1 Score = ', f1)

    if CS > CS_prev:
        alpha_best = alpha_val
        CS_prev = CS

print(f"Best alpha = {alpha_best}")

# Uncomment to draw the graph estimated at the last alpha tested.
# plot_directed_graph(est_essential_graph_freq)