"""
This script applies the Fisher-Z-test-based PC algorithm to a 20-node network of a dynamical system and reports the results.
"""

from causallearn.search.ConstraintBased.PC import pc
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
import time


def plot_directed_graph(adjacency_matrix):
    """Draw the directed graph encoded by ``adjacency_matrix`` and save it as a PDF.

    Nodes are labelled ``X1`` .. ``Xn``. A truthy entry
    ``adjacency_matrix[i, j]`` is drawn as the edge ``X(j+1) -> X(i+1)``
    (row = edge head, column = edge tail).

    The figure is written to
    ``"Network Reconstruction Random Variable independence model.pdf"`` in the
    current working directory and then shown without blocking.

    Parameters
    ----------
    adjacency_matrix : numpy.ndarray
        Square matrix; only ``shape[0]`` is used to determine the node count.
    """
    nNodes = adjacency_matrix.shape[0]
    labels = ['X' + str(ind) for ind in range(1, nNodes + 1)]

    G = nx.DiGraph()
    G.add_nodes_from(labels)
    for ind1 in range(nNodes):
        for ind2 in range(nNodes):
            if adjacency_matrix[ind1, ind2]:
                G.add_edge(labels[ind2], labels[ind1])

    # Hand-tuned positions, available only for the first six nodes.
    manual_pos = {'X'+str(1): (-1, 0.3), 'X'+str(2): (0, 0),
                  'X'+str(3): (1.7, 0.42), 'X'+str(4): (4.5, 0.455),
                  'X'+str(5): (5, -0.06), 'X'+str(6): (8, 0.13)}
    if all(label in manual_pos for label in labels):
        pos = manual_pos
    else:
        # draw_networkx needs a position for every node; larger graphs would
        # otherwise fail, so fall back to an automatically computed layout.
        pos = nx.circular_layout(G)

    options = {
        "font_size": 26,
        "node_size": 3000,
        "node_color": "white",
        "edgecolors": "black",
        "linewidths": 3,
        "width": 2,
    }

    nx.draw_networkx(G, pos, **options)
    # Set margins for the axes so that nodes aren't clipped
    ax = plt.gca()
    ax.margins(0.20)
    plt.axis("off")
    fig1 = plt.gcf()
    fig1.savefig("Network Reconstruction Random Variable independence model"+".pdf")
    # Non-blocking show plus a short pause lets the window render while the
    # script keeps running.
    plt.show(block=False)
    plt.pause(0.001)
   

    
    
def retrieve_adj(cg):
    """Convert a causal-learn result graph into a 0/1 adjacency matrix.

    ``cg.G.graph`` is read as an endpoint matrix in which, per causal-learn's
    documented encoding, ``graph[j, i] == 1 and graph[i, j] == -1`` denotes
    the directed edge ``i -> j`` and ``graph[i, j] == graph[j, i] == -1``
    denotes the undirected edge ``i - j``.

    Parameters
    ----------
    cg : object
        Any object exposing a square array as ``cg.G.graph``.

    Returns
    -------
    numpy.ndarray
        Float matrix ``Adj`` where ``Adj[j, i] == 1`` for a directed edge
        ``i -> j`` and ``Adj[i, j] == Adj[j, i] == 1`` for an undirected edge.
        All other entries are 0.
    """
    endpoints = cg.G.graph
    nNodes = endpoints.shape[0]
    Adj = np.zeros([nNodes, nNodes])
    for i in range(nNodes):
        for j in range(nNodes):
            # Both recognised edge types carry a -1 endpoint mark at [i, j].
            if endpoints[i, j] != -1:
                continue
            if endpoints[j, i] == 1:
                # Directed edge i -> j.
                Adj[j, i] = 1
            elif endpoints[j, i] == -1:
                # Undirected edge: mark both directions, never the diagonal,
                # so no spurious self-loops appear in the result.
                Adj[j, i] = 1
                Adj[i, j] = 1

    return Adj

# Find moral graph and skeleton for a given graph given its adjacency matrix
# Input:  (1) Adj: Adjacency matrix of directed graph
# Output: (1) MG: Adjacency matrix of the moral graph
#         (2) Skeleton: Adjacency matrix of the skeleton
def find_moral_graph_and_skeleton_from_adjacency_matrix(Adj):
    """Return ``(MG, Skeleton)`` derived from the directed adjacency matrix ``Adj``.

    ``Skeleton`` is ``Adj + Adj.T``. ``MG`` starts as a copy of the skeleton
    and additionally links every pair of distinct nodes that have a ``1`` in
    the same row of ``Adj``. Under the row = child, column = parent convention
    used by the plotting code in this file, those are parents sharing a child.
    """
    node_count = Adj.shape[0]
    Skeleton = Adj + Adj.T
    MG = np.copy(Skeleton)
    for child in range(node_count):
        # Columns holding exactly 1 in this row.
        parents = [col for col in range(node_count) if Adj[child, col] == 1]
        for first in parents:
            for second in parents:
                if first == second:
                    continue
                MG[first, second] = 1
                MG[second, first] = 1
    return MG, Skeleton

def find_true_Essential_graph(Skeleton, Adj):
    """Build the essential-graph matrix from a DAG's skeleton and adjacency.

    The matrices follow the row = child, column = parent convention:
    ``Adj[c, p] == 1`` means ``p -> c``. In the returned matrix an edge that
    is still undirected has a 1 in both ``[i, j]`` and ``[j, i]``; a directed
    edge ``p -> c`` keeps only ``[c, p] == 1``.

    Steps performed:
      1. Every unshielded collider ``p1 -> c <- p2`` (``p1`` and ``p2`` not
         adjacent in ``Skeleton``) has its two edges oriented into ``c``.
      2. For every edge still undirected after step 1, if one endpoint is a
         collider node, the entry pointing into that collider is cleared.

    Parameters
    ----------
    Skeleton : numpy.ndarray
        Symmetric skeleton adjacency matrix (not modified).
    Adj : numpy.ndarray
        Directed adjacency matrix of the true graph.

    Returns
    -------
    numpy.ndarray
        A new matrix with the same shape as ``Skeleton``.
    """
    # Take the size from the input rather than a module-level global, so the
    # function works for any graph and wherever it is called from.
    num_nodes = Adj.shape[0]

    collider_graph = np.copy(Skeleton)
    collider_nodes = set()
    for ind1 in range(num_nodes):  # ind1 is the collider node index
        # ind2 and ind3 are the indices for the parents of the collider
        for ind2 in range(num_nodes):
            for ind3 in range(num_nodes):
                if (Adj[ind1, ind2] == 1 and Adj[ind1, ind3] == 1
                        and Skeleton[ind2, ind3] != 1
                        and ind2 != ind3 and ind1 != ind3 and ind1 != ind2):
                    # Keep only the parent -> collider direction.
                    collider_graph[ind2, ind1] = 0
                    collider_graph[ind3, ind1] = 0
                    collider_nodes.add(ind1)

    # Orienting other edges if possible
    true_ess_graph = np.copy(collider_graph)
    for i in range(num_nodes):
        for j in range(i):  # visit each unordered pair once (i > j)
            if collider_graph[i, j] == 1 and collider_graph[j, i] == 1:
                if i in collider_nodes:
                    true_ess_graph[i, j] = 0
                if j in collider_nodes:
                    true_ess_graph[j, i] = 0

    return true_ess_graph

def get_confusion_matrix(true_graph,estimated_graph):
    """Count edge-wise agreement between two 0/1 adjacency matrices.

    Every off-diagonal entry ``[i, j]`` is classified by comparing the true
    and estimated values; diagonal entries are ignored. Entries whose value is
    neither 0 nor 1 contribute to no counter.

    Returns
    -------
    tuple
        ``(TP, FP, TN, FN)`` as plain integers.
    """
    size = true_graph.shape[0]
    true_pos = false_pos = true_neg = false_neg = 0
    for row, col in np.ndindex(size, size):
        if row == col:
            continue
        actual = true_graph[row, col]
        predicted = estimated_graph[row, col]
        if actual == 1:
            if predicted == 1:
                true_pos += 1
            elif predicted == 0:
                false_neg += 1
        elif actual == 0:
            if predicted == 1:
                false_pos += 1
            elif predicted == 0:
                true_neg += 1

    return true_pos, false_pos, true_neg, false_neg

# Number of nodes in the ground-truth network. `num_nodes` is a second name
# for the same value and is read as a global by find_true_Essential_graph.
nNodes = 20
num_nodes = nNodes
# Ground-truth directed adjacency matrix. Following the convention used by the
# helper functions in this file, Adj[i, j] == 1 encodes the edge j -> i
# (row = child, column = parent).
Adj = np.array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0]])

# Moral graph and skeleton of the ground-truth graph (MG is not used further
# in the visible part of this script).
MG,Skeleton = find_moral_graph_and_skeleton_from_adjacency_matrix(Adj)

# Reference essential graph that the PC output is scored against.
true_ess_graph = find_true_Essential_graph(Skeleton, Adj)

# Space-delimited numeric data file, read from the current working directory.
# NOTE(review): presumably one column per node (20 columns) - confirm.
data=np.loadtxt('dataSet_20Nodes.txt', delimiter=' ')


# Sweep the significance level of the Fisher-Z conditional-independence test
# and score the graph recovered by PC at each level against the reference.
alpha_vec = np.arange(0.01,0.5,0.05)
for alpha_val in alpha_vec:
    start_time = time.time()
    # Use the swept value so each iteration runs at the alpha it reports.
    cg = pc(data, alpha=alpha_val, indep_test='fisherz')
    end_time = time.time()
    est_essential_graph_freq=retrieve_adj(cg)

    # Fraction of mismatching entries relative to the number of true entries.
    error_freq_PC = sum(sum(est_essential_graph_freq!=true_ess_graph))/sum(sum(true_ess_graph))
    [TP, FP, TN, FN] = get_confusion_matrix(true_ess_graph, est_essential_graph_freq)
    TPR = TP/(TP+FN)
    FPR = FP/(TN+FP)
    # Combined score: true-positive rate minus false-positive rate.
    CS = TPR - FPR
    computation_time = end_time-start_time
    print("----------------------------------------______")
    print(f"alpha = {alpha_val}")
    print(f"TPR = {TPR}, FPR = {FPR}")
    print(f"Elapsed algorithm run time = {computation_time}")


# plot_directed_graph(est_essential_graph_freq)