import networkx as nx
import numpy as np
from numpy.random import default_rng
rng = default_rng(seed=111)
from timeawarepc.tpc import cfc_tpc, cfc_pc
from timeawarepc.gc import cfc_gc
from timeawarepc.simulate_data import *
from timeawarepc.find_cfc import *
import time


def get_confusion_matrix(true_graph,estimated_graph):
    """Compare two adjacency matrices edge by edge, ignoring self-loops.

    Only the off-diagonal cells of the leading ``n x n`` block are scored,
    where ``n = true_graph.shape[0]``. A cell equal to 1 is an edge and a cell
    equal to 0 is a non-edge; cells holding any other value are not counted
    in either matrix.

    Returns:
        ``(TP, FP, TN, FN)`` as Python ints.
    """
    n = true_graph.shape[0]
    truth = np.asarray(true_graph)[:n, :n]
    guess = np.asarray(estimated_graph)[:n, :n]
    scored = ~np.eye(n, dtype=bool)  # mask out the diagonal (self-loops)

    has_edge, no_edge = (truth == 1) & scored, (truth == 0) & scored
    says_edge, says_none = guess == 1, guess == 0

    def tally(mask):
        return int(np.count_nonzero(mask))

    return (
        tally(has_edge & says_edge),   # TP
        tally(no_edge & says_edge),    # FP
        tally(no_edge & says_none),    # TN
        tally(has_edge & says_none),   # FN
    )

def find_moral_graph_and_skeleton_from_adjacency_matrix(Adj):
    """Compute the skeleton and moral graph of the DAG encoded by ``Adj``.

    ``Adj`` uses the ``[child, parent]`` convention: ``Adj[c, p] == 1`` means
    ``p -> c``.  Only entries exactly equal to 1 are treated as edges when
    marrying parents.

    Args:
        Adj: square adjacency matrix of the directed graph.

    Returns:
        (MG, Skeleton):
            * ``Skeleton`` -- ``Adj + Adj.T``, the symmetric undirected skeleton.
            * ``MG`` -- a copy of ``Skeleton`` (same dtype) in which every pair
              of distinct nodes sharing a common child is additionally joined.
    """
    Skeleton = Adj + Adj.T
    MG = np.copy(Skeleton)
    # Column j of is_parent flags the children of j, so
    # (is_parent.T @ is_parent)[a, b] > 0 exactly when a and b share a child.
    is_parent = np.asarray(Adj == 1, dtype=int)
    shares_child = (is_parent.T @ is_parent) > 0
    np.fill_diagonal(shares_child, False)  # never marry a node to itself
    MG[shares_child] = 1
    return MG, Skeleton

def find_true_Essential_graph(Skeleton, Adj):
    """Return the essential graph (CPDAG) of the DAG encoded by ``Adj``.

    Matrices use the ``[child, parent]`` convention: ``Adj[c, p] == 1`` means
    ``p -> c``.  In the returned matrix ``G`` (a copy of ``Skeleton``, same
    dtype):

    * ``G[c, p] == 1`` and ``G[p, c] == 0``  -> compelled edge ``p -> c``;
    * ``G[a, b] == 1`` and ``G[b, a] == 1``  -> reversible (undirected) edge;
    * both zero                              -> no edge.

    Construction: start from the skeleton, orient every v-structure
    ``a -> c <- b`` (``a``, ``b`` non-adjacent), then apply Meek's rules
    R1-R3 repeatedly until no further edge can be oriented.  Edges are only
    ever oriented, never removed.

    Args:
        Skeleton: symmetric adjacency of the undirected skeleton
            (``Skeleton[a, b] == 1`` iff ``a`` and ``b`` are adjacent).
        Adj: ``[child, parent]`` adjacency matrix of the DAG.

    Returns:
        The essential-graph matrix described above.
    """
    num_nodes = Adj.shape[0]
    nodes = range(num_nodes)
    ess = np.copy(Skeleton)

    # Edge-mark queries on the [child, parent] matrix ``ess`` (mutated in place,
    # so these closures always see the current orientation state).
    def directed(u, v):      # u -> v is compelled
        return ess[v, u] == 1 and ess[u, v] == 0

    def undirected(u, v):    # u - v is still unoriented
        return ess[u, v] == 1 and ess[v, u] == 1

    def adjacent(u, v):
        return ess[u, v] == 1 or ess[v, u] == 1

    # 1) V-structures: a -> c <- b with a and b non-adjacent.
    for c in nodes:
        parents = [p for p in nodes if p != c and Adj[c, p] == 1]
        for i, a in enumerate(parents):
            for b in parents[i + 1:]:
                if Skeleton[a, b] != 1:
                    ess[a, c] = 0
                    ess[b, c] = 0

    def forced(a, b):
        """Return True if Meek R1-R3 force the undirected edge a - b to a -> b."""
        for c in nodes:
            if c == a or c == b:
                continue
            # R1: c -> a - b with c, b non-adjacent (avoids a new v-structure).
            if directed(c, a) and not adjacent(c, b):
                return True
            # R2: a -> c -> b (b -> a would create a cycle).
            if directed(a, c) and directed(c, b):
                return True
        # R3: a - c -> b and a - d -> b with c, d non-adjacent.
        kites = [c for c in nodes
                 if c != a and c != b and undirected(a, c) and directed(c, b)]
        return any(not adjacent(c, d)
                   for i, c in enumerate(kites) for d in kites[i + 1:])

    # 2) Propagate orientations until a fixed point is reached.
    changed = True
    while changed:
        changed = False
        for a in nodes:
            for b in nodes:
                if a != b and undirected(a, b) and forced(a, b):
                    ess[a, b] = 0  # drop the "b is a parent of a" mark: a -> b
                    changed = True

    return ess


data = np.loadtxt('FourNodesBJTHardwareData1M.txt', dtype='d', delimiter=',')

# Ground-truth DAG in [child, parent] form (Adj[c, p] == 1 means p -> c):
# 0 -> 2 <- 1 and 2 -> 3.
Adj = np.array([[0, 0, 0, 0],
                [0, 0, 0, 0],
                [1, 1, 0, 0],
                [0, 0, 1, 0]])

MG,Skeleton = find_moral_graph_and_skeleton_from_adjacency_matrix(Adj)

true_ess_graph_t = find_true_Essential_graph(Skeleton, Adj)

# Transpose to [parent, child] form (true_ess_graph[p, c] == 1 means p -> c),
# presumably the orientation of cfc_tpc's estimate -- TODO confirm.
true_ess_graph = true_ess_graph_t.T


def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or 0.0 when the denominator is 0.

    Keeps TPR/FPR/precision/F1 defined when an alpha yields no predicted
    edges or no true positives, instead of raising ZeroDivisionError.
    """
    return numerator / denominator if denominator else 0.0


# Sweep the significance level alpha and keep the value maximising
# CS = TPR - FPR.  CS_prev starts at -inf so alpha_best is always one of the
# swept values, even when every alpha scores CS <= 0.
CS_prev = -np.inf
alpha_best = None
alpha_arr = np.arange(0.01,0.5,0.05)
for alpha_num in alpha_arr:
    # perf_counter is monotonic and high-resolution, unlike time.time().
    start_time = time.perf_counter()
    adjmat = cfc_tpc(data[:128000, :], maxdelay=5, subsampsize=50, niter=50,
                     alpha=alpha_num, thresh=0.01, isgauss=True)
    end_time = time.perf_counter()
    elapsed_time = end_time - start_time

    # adjmat[0] is assumed to be the estimated binary adjacency -- TODO confirm.
    TP, FP, TN, FN = get_confusion_matrix(true_ess_graph, adjmat[0])
    TPR = _safe_ratio(TP, TP + FN)
    FPR = _safe_ratio(FP, TN + FP)
    CS = TPR - FPR
    precision = _safe_ratio(TP, TP + FP)
    recall    = TPR
    f1        = _safe_ratio(2 * precision * recall, precision + recall)

    print("-----------------------------------")
    print(f"alpha = {alpha_num}")
    print('F1 Score = ',f1)
    print(f"TPR = {TPR}")
    print(f"FPR = {FPR}")
    print(f"Total algorithm run time = {elapsed_time}")
    if CS > CS_prev:
        alpha_best = alpha_num
        CS_prev = CS

print(f"Best alpha = {alpha_best}")