"""
This script generates data for a 20-node network with a quadratic
nonlinearity, runs the FFT-based Wiener-projection PC (WPC) algorithm with the
selected threshold and sample count, and reports the accuracy (TPR/FPR) and the
computation time.
"""


import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
from itertools import combinations
import time
import random




# %% Frequency domain projection functions

def Wiener_Proj(X,Z):
    """Return the least-squares coefficients for projecting Z onto the columns of X.

    Solves the normal equations (X^H X) a = X^H Z, so that X @ a is the
    orthogonal projection of Z onto the span of the columns of X.

    Parameters
    ----------
    X : ndarray, shape (n,) or (n, m)
        Regressors (complex allowed); each column is one regressor.
    Z : ndarray, shape (n,)
        Target vector.

    Returns
    -------
    ndarray of shape (m,), or a scalar when X is 1-D.

    Raises
    ------
    numpy.linalg.LinAlgError
        If X^H X is singular (2-D case).
    """
    X_star = X.conj().T
    XX = np.matmul(X_star, X)
    if np.isscalar(XX):
        # 1-D X: matmul of two vectors yields a scalar Gram "matrix".
        return np.matmul(np.divide(X_star, XX), Z)
    # Solve the normal equations directly instead of forming an explicit
    # inverse; this reuses XX and is numerically better conditioned.
    return np.linalg.solve(XX, np.matmul(X_star, Z))


def create_A_B_coeffs_6_nodes(Adj,filter=None):
    """Build the IIR coefficient arrays for the 6-node network.

    Parameters
    ----------
    Adj : ndarray, shape (n, n)
        Adjacency mask; B is zeroed wherever Adj is 0.
    filter : unused
        Accepted for interface compatibility only.

    Returns
    -------
    A : ndarray, shape (6, 3)
        Fixed coefficients for each node's interaction with its own past.
    B : ndarray, shape (n, n, 3)
        Coupling coefficients to the other nodes, drawn Uniform(0.2, 0.4)
        from np.random and masked by Adj.
    """
    size = Adj.shape[0]

    A = np.array([
        [1.739381261594696859e-02, 1.043628756956818150e-02, 5.218143784784090751e-03],
        [3.415208172002098808e-01, 1.366083268800839523e-01, 6.830416344004197615e-02],
        [8.289102927104288199e-01, 5.802372048973001295e-01, 3.315641170841715502e-01],
        [8.846579905691620560e-02, 4.423289952845810280e-02, 2.653973971707486099e-02],
        [8.562817629518102436e-01, 5.137690577710861684e-01, 2.568845288855430842e-01],
        [1.630296867565628194e-01, 1.141207807295939597e-01, 4.890890602696884581e-02],
    ])

    # Broadcasting the mask over the lag axis is equivalent to repeating it.
    lag_mask = Adj[:, :, np.newaxis]
    B = np.random.uniform(0.2, 0.4, size=(size, size, 3)) * lag_mask
    return A, B
    

def create_A_B_coeffs_20_nodes(Adj,filter=None):
    """Build the IIR coefficient arrays for the 20-node network.

    Parameters
    ----------
    Adj : ndarray, shape (n, n)
        Adjacency mask; B is zeroed wherever Adj is 0.
    filter : unused
        Accepted for interface compatibility only.

    Returns
    -------
    A : ndarray, shape (20, 3)
        Self-dynamics coefficients. The rows cycle through six base rows
        (base rows 0-5, 0-5, 0-5, then 0-1).
    B : ndarray, shape (n, n, 3)
        Coupling coefficients to the other nodes, drawn Uniform(0.2, 0.4)
        from np.random and masked by Adj.
    """
    size = Adj.shape[0]

    base_rows = np.array([
        [1.739381261594696859e-02, 1.043628756956818150e-02, 5.218143784784090751e-03],
        [3.415208172002098808e-01, 1.366083268800839523e-01, 6.830416344004197615e-02],
        [8.289102927104288199e-01, 5.802372048973001295e-01, 3.315641170841715502e-01],
        [8.846579905691620560e-02, 4.423289952845810280e-02, 2.653973971707486099e-02],
        [8.562817629518102436e-01, 5.137690577710861684e-01, 2.568845288855430842e-01],
        [1.630296867565628194e-01, 1.141207807295939597e-01, 4.890890602696884581e-02],
    ])
    A = base_rows[np.arange(20) % base_rows.shape[0]]

    lag_mask = Adj[:, :, np.newaxis]
    B = np.random.uniform(0.2, 0.4, size=(size, size, 3)) * lag_mask
    return A, B

    
def continuous_data_generation_arbitrary_lags(nSamples,B,A=None,noise_pow=-1,nlags=0):
    """Generate time-series data from a VAR-type model with a quadratic nonlinearity.

    For t >= nlags each sample is

        x[t] = e[t] - A[:,0]*x[t-1] - A[:,1]*x[t-2] - A[:,2]*x[t-3]
               + B[:,:,0] @ x[t-nlags] + 0.1 * B[:,:,0] @ x[t-nlags]**2

    where e[t] is i.i.d. Gaussian noise with per-node variance `noise_pow`.
    The first `nlags` rows are pure noise.

    Parameters
    ----------
    nSamples : int
        Total number of samples.
    B : ndarray, shape (nNodes, nNodes, k)
        Coupling to the other nodes; only B[:, :, 0] is used.
    A : ndarray, shape (nNodes, >=3)
        Self-dynamics (IIR) coefficients. Required despite the None default.
    noise_pow : scalar or array-like, shape (nNodes,)
        Noise variance per node. The sentinel -1 (the default) means unit
        variance for every node.
    nlags : int
        Lag at which the other nodes act on each node.

    Returns
    -------
    x : ndarray, shape (nSamples, nNodes)
        Time series for the graph entailed by A and B.
    """
    nNodes = B.shape[0]
    # Nodes whose lagged value enters the nonlinear (squared) term.
    indices_to_square = list(range(nNodes))
    noise_pow = np.asarray(noise_pow, dtype=float)
    # The old `noise_pow.all() == -1` compared a bool with -1 (never true) and
    # crashed on the default int; test the sentinel value explicitly instead.
    if np.all(noise_pow == -1):
        noise_pow = np.ones(nNodes)
    noise_std = np.sqrt(noise_pow)
    print("Generating continous data.. nSamples: ",nSamples," for nNodes :",nNodes)
    x = np.zeros((nSamples,nNodes))
    print(f"B={B}")
    B0 = B[:, :, 0]
    for ind in range(nlags):
        x[ind,:] = np.random.randn(nNodes)*noise_std
    for t in range(nlags,nSamples):
        x[t,:] = np.random.randn(nNodes)*noise_std - A[:,0]*x[t-1,:] - A[:,1]*x[t-2,:] - A[:,2]*x[t-3,:]
        # Use the `nlags` parameter, not the module-level global `nLags`.
        x_non_linear = np.copy(x[t-nlags,:])
        x_non_linear[indices_to_square] = np.square(x_non_linear[indices_to_square])
        x[t,:] = x[t,:] + np.dot(B0,x[t-nlags,:]) + 0.1*np.dot(B0,x_non_linear)
    return x




def compute_Wiener_coeffs(nNodes,data):
    """Project every node onto all other nodes, separately at each frequency bin.

    Parameters
    ----------
    nNodes : int
        Number of nodes (columns of each frequency slice).
    data : complex ndarray, shape (nTrajectories, nNodes, n_freq)
        Per-window FFT data; data[:, :, f] holds frequency bin f.

    Returns
    -------
    W : complex ndarray, shape (nNodes, n_freq, nNodes)
        W[p, f, :] are the coefficients of node p on the other nodes at bin f,
        with a 0 inserted at position p.
    W_mag : ndarray, shape (nNodes, nNodes)
        Max over frequency bins of |W[p, :, q]| (H-infinity norm estimate).
    """
    n_freq = data.shape[2]
    W = np.zeros([nNodes, n_freq, nNodes], dtype='complex')
    W_mag = np.zeros([nNodes, nNodes])
    for target in range(nNodes):
        coeffs = np.zeros([n_freq, nNodes], dtype='complex')
        for k in range(n_freq):
            snapshot = data[:, :, k]
            regressors = np.delete(snapshot, target, 1)
            projected = Wiener_Proj(regressors, snapshot[:, target])
            coeffs[k, :] = np.insert(projected, target, 0)
        W[target, :, :] = coeffs
        W_mag[target, :] = np.abs(coeffs).max(axis=0)
    return W, W_mag




def compute_partial_Wiener_coeffs_hinf(data,i,j,z=None,freq_choice=None):
    """Project node i onto [j, *z] at each frequency and take the peak coefficient magnitude.

    Parameters
    ----------
    data : complex ndarray, shape (nTrajectories, nNodes, n_FFT)
        Per-window FFTs (e.g. from compute_fft).
    i : int
        Index of the node being projected.
    j : int
        Node whose coefficient is of interest; it is always column 0.
    z : iterable of int, optional
        Conditioning nodes (default: none). Duplicates, including j, are
        dropped while keeping first-occurrence order.
    freq_choice : iterable of int, optional
        Frequency bins to use. Defaults to all n_FFT bins; the previous default
        (an empty list) computed nothing and returned all zeros.

    Returns
    -------
    W : complex ndarray, shape (n_FFT, len(c))
        Projection coefficients; rows of bins not in freq_choice stay zero.
    W_mag : ndarray, shape (len(c),)
        Max over frequency bins of |W| (H-infinity norm estimate).
    """
    n_FFT = data.shape[2]
    if z is None:
        z = []
    if freq_choice is None:
        freq_choice = range(n_FFT)
    # Regressor columns: j first, then z, duplicates removed in first-seen order.
    c = list(dict.fromkeys(np.append(j, np.asarray(z, dtype=np.int32)).tolist()))
    W = np.zeros([n_FFT, len(c)], dtype='complex')
    for freq_index in freq_choice:
        fft_data = data[:, :, freq_index]
        W[freq_index, :] = Wiener_Proj(fft_data[:, c], fft_data[:, i])
    W_mag = np.amax(np.abs(W), axis=0)
    return W, W_mag

def compute_partial_Wiener_coeffs_avg(data,i,j,z=None,freq_choice=None):
    """Project node i onto [j, *z] at each chosen frequency and average the coefficient magnitudes.

    Parameters
    ----------
    data : complex ndarray, shape (nTrajectories, nNodes, n_FFT)
        Per-window FFTs (e.g. from compute_fft).
    i : int
        Index of the node being projected.
    j : int
        Node whose coefficient is of interest; it is always column 0.
    z : iterable of int, optional
        Conditioning nodes (default: none). Duplicates, including j, are
        dropped while keeping first-occurrence order.
    freq_choice : iterable of int, optional
        Frequency bins to average over. Defaults to all n_FFT bins; the
        previous default (an empty list) produced 0/0 = NaN.

    Returns
    -------
    W : complex ndarray, shape (n_FFT, len(c))
        Projection coefficients; rows of bins not in freq_choice stay zero.
    W_mag_avg : ndarray, shape (len(c),)
        Mean over freq_choice of |W[f, :]|.

    Raises
    ------
    ValueError
        If freq_choice is empty.
    """
    n_FFT = data.shape[2]
    if z is None:
        z = []
    freq_list = list(range(n_FFT)) if freq_choice is None else list(freq_choice)
    if not freq_list:
        raise ValueError("freq_choice must contain at least one frequency index")
    # Regressor columns: j first, then z, duplicates removed in first-seen order.
    c = list(dict.fromkeys(np.append(j, np.asarray(z, dtype=np.int32)).tolist()))
    W = np.zeros([n_FFT, len(c)], dtype='complex')
    W_mag_sum = np.zeros([len(c)], dtype=float)
    for freq_index in freq_list:
        fft_data = data[:, :, freq_index]
        W[freq_index, :] = Wiener_Proj(fft_data[:, c], fft_data[:, i])
        W_mag_sum = W_mag_sum + np.abs(W[freq_index, :])
    W_mag_avg = np.divide(W_mag_sum, len(freq_list))
    return W, W_mag_avg

def compute_fft(data,nfft):
    """Split the data into consecutive non-overlapping windows and FFT each window.

    Parameters
    ----------
    data : ndarray, shape (nSamples, nNodes)
        Time-series data, one column per node.
    nfft : int
        Window length, which is also the number of FFT points.

    Returns
    -------
    y : complex ndarray, shape (nSamples // nfft, nNodes, nfft)
        y[t, k, :] is the FFT of node k over window t. Trailing samples that
        do not fill a complete window are discarded.
    """
    nSamples, nNodes = data.shape
    nTrajectories = nSamples // nfft
    y = np.zeros((nTrajectories, nNodes, nfft), dtype=complex)
    # Every complete window is transformed. The old `range(nTrajectories-1)`
    # left the last full window as zeros, silently wasting one trajectory.
    for ind in range(nTrajectories):
        window = data[ind*nfft:(ind+1)*nfft, :]
        y[ind, :, :] = np.transpose(np.fft.fft(window, axis=0))  # fft is (nfft, nNodes)
    return y




def Run_PC_Wiener_test_FFT_avg(data,nFFT,threshd,Q):
    """Run the skeleton phase of a PC-style search using frequency-averaged Wiener projections.

    For each node pair (ind1, ind2) with ind1 > ind2, conditioning sets drawn
    from the remaining nodes are tried in increasing size, from 0 up to
    min(Q, nNodes - 2). Every combination of each size is tried. The pair is
    declared separated by set c as soon as the frequency-averaged magnitude of
    node ind2's coefficient, when projecting ind1 onto [ind2, *c], drops below
    `threshd`.

    Parameters
    ----------
    data : ndarray, shape (nSamples, nNodes)
        Time-series data.
    nFFT : int
        FFT window length used by compute_fft.
    threshd : float
        Separation threshold on the averaged coefficient magnitude.
    Q : int
        Maximum conditioning-set size.

    Returns
    -------
    C : dict
        Maps (ind1, ind2), ind1 > ind2, to the separating set found (a tuple,
        possibly empty), or to None when no tested set separates the pair,
        i.e. the edge is kept. Every such pair gets a key.
    """
    nNodes = data.shape[1]
    C = {}
    data_FFT = compute_fft(data, nfft=nFFT)
    freq_choice = range(nFFT)
    for ind1 in range(nNodes):
        for ind2 in range(ind1):
            # Candidate conditioning nodes: everything except the pair itself.
            D = [k for k in range(nNodes) if k != ind1 and k != ind2]
            sep_set = None
            # Test ALL combinations up to size Q. The old code gave up after
            # the first size-Q combination, and set no key at all when
            # Q > nNodes-2.
            for size in range(min(Q, len(D)) + 1):
                for c in combinations(D, size):
                    _, W_mag = compute_partial_Wiener_coeffs_avg(data_FFT, ind1, ind2, c, freq_choice)
                    if W_mag[0] < threshd:
                        sep_set = c
                        break
                if sep_set is not None:
                    break
            C[ind1, ind2] = sep_set
    return C


################################################################################################################################
def data_processing(original_data_raw,nSamples,n_FFT,nNodes):
    """FFT `nSamples` consecutive windows of length `n_FFT` taken from the raw data.

    Parameters
    ----------
    original_data_raw : ndarray, shape (>= nSamples * n_FFT, nNodes)
    nSamples : int
        Number of windows (trajectories) to transform.
    n_FFT : int
        Window length.
    nNodes : int
        Number of columns.

    Returns
    -------
    y : complex ndarray, shape (nSamples, nNodes, n_FFT)
    """
    y = np.zeros((nSamples, nNodes, n_FFT), dtype=complex)
    for traj in range(nSamples):
        offset = traj * n_FFT
        segment = original_data_raw[offset:offset + n_FFT, :]
        y[traj, :, :] = np.fft.fft(segment, axis=0).T
    return y


def normalize_data(data):
    """Z-score each column: subtract the column mean, divide by the column (population) std."""
    return (data - np.mean(data, axis=0)) / np.std(data, axis=0)

# %% Essential graph estimation function

def estimate_essential_graph(Top_est,C):
    """Earlier collider-detection routine for an estimated skeleton.

    The script calls estimate_essential_graph_new instead of this function.

    Parameters
    ----------
    Top_est : ndarray, shape (num_nodes, num_nodes)
        Estimated skeleton (1 = adjacent).
    C : dict
        Separating sets keyed by (larger, smaller) node index; values are
        tuples of node indices, or None when no separating set was found.

    Returns
    -------
    Est_adj : ndarray
        Copy of Top_est with entries [ind2, ind1] and [ind3, ind1] zeroed for
        every detected collider ind1.
    colliders : list
        Unique collider node indices (order not preserved).

    NOTE(review):
      - Loops over the module-level global ``num_nodes`` rather than
        ``Top_est.shape[0]``.
      - The debug ``print`` reads ``C[ind1,ind2]``. With C produced by
        Run_PC_Wiener_test_FFT_avg (keys only for first index > second), this
        raises KeyError whenever a pair with ind1 < ind2 passes the adjacency
        test.
      - Unlike the usual v-structure test, ind2 and ind3 are not required to be
        non-adjacent.
    """
    colliders = list()
    Est_adj = np.copy(Top_est)
    # # Estimate MEG from skeleton and d-separating set, C
    for ind1 in range(num_nodes):
        for ind2 in range(num_nodes):   
            for ind3 in range(ind2+1,num_nodes):
                # ind1 is adjacent to both ind2 and ind3 in the skeleton.
                # If ind1 is NOT in the separating set of (ind3, ind2), treat
                # ind1 as a collider ind2 -> ind1 <- ind3.
                # (ind3 > ind2 here, so (ind3, ind2) matches the key order of C.)
                if (Top_est[ind1,ind2]==1) and (Top_est[ind1,ind3]==1) and ind1!=ind3 and ind2!=ind1:
                    print(ind1,ind2,ind3,C[ind1,ind2])
                    if C[(ind3,ind2)]!=None: #None means no d-separating set
                        if (ind1 not in C[(ind3,ind2)]):
                            print(ind1, " is a collider")
                            colliders.append(ind1)
                            # Keep [ind1, parent] and zero [parent, ind1] so
                            # only the parent -> ind1 direction remains.
                            Est_adj[ind2,ind1] = 0
                            Est_adj[ind3,ind1] = 0
                    # else:
                    #     print(ind1, " :: is a collider")
                    #     Est_adj[ind2,ind1]=0
                    #     Est_adj[ind3,ind1]=0
                    #     colliders.append(ind1)
                    # pass
    # Deduplicate: the same collider can be found for several parent pairs.
    colliders=list(set(colliders))
    return Est_adj,colliders

def estimate_essential_graph_new(Top_est,C):
    """Orient an estimated skeleton using collider (v-structure) detection.

    Convention, shared with find_true_Essential_graph: for an oriented edge
    parent -> child the entry [child, parent] stays 1 and [parent, child] is
    zeroed.

    Parameters
    ----------
    Top_est : ndarray, shape (n, n)
        Symmetric skeleton estimate (1 = adjacent).
    C : dict
        Separating sets keyed by (i, j) with i > j; values are tuples of node
        indices, or None when no separating set was found. A pair's entry is
        only read when the pair is non-adjacent and has a common neighbour.

    Returns
    -------
    final_graph : ndarray
        collider_graph with additional undirected edges oriented away from
        detected colliders.
    collider_graph : ndarray
        Skeleton with only the detected v-structures oriented.
    """
    # Derive the size from the input instead of the module-level global nNodes.
    n = Top_est.shape[0]

    # Collider identification: i, j non-adjacent with common neighbour k, and
    # k absent from their separating set => i -> k <- j.
    collider_graph = np.copy(Top_est)
    collider_set = set()
    for i in range(n):
        for j in range(i):
            if Top_est[i, j] == 1:
                continue
            for k in range(n):
                if k == i or k == j or Top_est[i, k] != 1 or Top_est[j, k] != 1:
                    continue
                sep_set = C[(i, j)]
                if sep_set is not None and k not in sep_set:
                    collider_graph[i, k] = 0
                    collider_graph[j, k] = 0
                    collider_set.add(k)

    # Orient remaining undirected edges away from colliders where possible.
    final_graph = np.copy(collider_graph)
    for i in range(n):
        for j in range(i):
            if collider_graph[i, j] == 1 and collider_graph[j, i] == 1:
                if i in collider_set:
                    final_graph[i, j] = 0
                if j in collider_set:
                    final_graph[j, i] = 0

    return final_graph, collider_graph


def get_confusion_matrix(true_graph,estimated_graph):
    """Count edge-level agreement between two adjacency matrices, ignoring the diagonal.

    Only entries exactly equal to 0 or 1 are counted.

    Returns
    -------
    (TP, FP, TN, FN) : tuple of int
    """
    size = true_graph.shape[0]
    TP = FP = TN = FN = 0
    off_diagonal = ((r, s) for r in range(size) for s in range(size) if r != s)
    for r, s in off_diagonal:
        truth = true_graph[r, s]
        guess = estimated_graph[r, s]
        if truth == 1:
            if guess == 1:
                TP += 1
            elif guess == 0:
                FN += 1
        elif truth == 0:
            if guess == 1:
                FP += 1
            elif guess == 0:
                TN += 1
    return TP, FP, TN, FN


def find_moral_graph_and_skeleton_from_adjacency_matrix(Adj):
    """Return the moral graph and the skeleton of a directed graph.

    Parameters
    ----------
    Adj : ndarray, shape (n, n)
        Directed adjacency matrix; Adj[child, parent] == 1 marks parent -> child.

    Returns
    -------
    MG : ndarray, shape (n, n)
        The skeleton plus an undirected link between every pair of distinct
        nodes that share a common child ("marrying the parents").
    Skeleton : ndarray, shape (n, n)
        Adj + Adj.T (entries are 2 where both directions are present).
    """
    Skeleton = Adj + Adj.T
    parent_of = (Adj == 1).astype(int)
    # shared_children[p, q] = number of nodes that have both p and q as parents.
    shared_children = parent_of.T @ parent_of
    married = shared_children > 0
    np.fill_diagonal(married, False)
    MG = np.copy(Skeleton)
    MG[married] = 1
    return MG, Skeleton

def find_true_Essential_graph(Skeleton, Adj):
    """Build the reference essential graph of the true network.

    Convention: for an oriented edge parent -> child the entry [child, parent]
    stays 1 and [parent, child] is zeroed.

    Parameters
    ----------
    Skeleton : ndarray, shape (n, n)
        Undirected skeleton of Adj (e.g. Adj + Adj.T).
    Adj : ndarray, shape (n, n)
        True directed adjacency; Adj[child, parent] == 1 marks parent -> child.

    Returns
    -------
    true_ess_graph : ndarray, shape (n, n)
        Skeleton with unshielded colliders oriented, plus undirected edges
        touching a collider oriented away from it.
    """
    # Size comes from Adj instead of the module-level global num_nodes.
    n_nodes = Adj.shape[0]
    collider_graph = np.copy(Skeleton)
    collider_nodes = set()
    for child in range(n_nodes):
        parents = [p for p in range(n_nodes) if p != child and Adj[child, p] == 1]
        for p1 in parents:
            for p2 in parents:
                # Unshielded collider p1 -> child <- p2.
                if p1 != p2 and Skeleton[p1, p2] != 1:
                    collider_graph[p1, child] = 0
                    collider_graph[p2, child] = 0
                    collider_nodes.add(child)
    # Orient other edges if possible.
    true_ess_graph = np.copy(collider_graph)
    for i in range(n_nodes):
        for j in range(i):
            if collider_graph[i, j] == 1 and collider_graph[j, i] == 1:
                if i in collider_nodes:
                    true_ess_graph[i, j] = 0
                if j in collider_nodes:
                    true_ess_graph[j, i] = 0
    return true_ess_graph



# NOTE: a second, byte-identical definition of
# find_moral_graph_and_skeleton_from_adjacency_matrix used to live here.
# It silently rebound the name defined earlier in this file (flake8/ruff F811),
# so it was removed; the earlier definition is the single source of truth.

# NOTE: a second, byte-identical definition of find_true_Essential_graph used
# to live here. It silently rebound the name defined earlier in this file
# (flake8/ruff F811), so it was removed; the earlier definition is the single
# source of truth.

# Main script ######################################################################################################################
# Simulation configuration. Several helper functions in this file read
# num_nodes / nNodes / nLags as module-level globals, so keep these names.
start_sim_time = time.time()
print("Start time: ", start_sim_time)

nSamples_list = [128000]        # total samples to simulate; the last entry is used
num_nodes = 20
nNodes = num_nodes              # alias of num_nodes
nLags = 3                       # lag at which nodes influence each other
n_FFT = 32                      # FFT window length
noise_pow1 = np.ones(num_nodes)  # unit noise variance per node

# Degree limits; not referenced elsewhere in this file.
max_in_degree = 2
max_out_degree = 1
avg_deg_lim = [2.1, 3.5]
max_deg_lim = 9
min_deg_lim = 1

# Seed both RNGs before any random draws.
random.seed(3)
np.random.seed(3)



########### ########### ########### ########### ########### ########### ########### 
# Network and Data Generation
########### ########### ########### ########### ########### ########### ###########
# Ground-truth adjacency of the 20-node network. Convention used throughout
# this file: Adj[i, j] == 1 means node j is a parent of node i (j -> i). For
# example, the data generator adds B[i, j, 0] * x_j(t - nLags) to x_i(t).
Adj = np.array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0]])


# Moral graph and skeleton of the true network (deterministic, no RNG use).
MG,Skeleton = find_moral_graph_and_skeleton_from_adjacency_matrix(Adj)
# Draws the coupling strengths B1 from np.random. Keep this call before the
# data generation below so the seeded random stream is consumed in the same
# order on every run.
A1,B1 = create_A_B_coeffs_20_nodes(Adj)

# True essential graph; used as the reference for the accuracy metrics.
true_ess_graph = find_true_Essential_graph(Skeleton, Adj)
########### ########### Generate data ############################################  
start_data=time.time()
print("Data generation started at ", start_data)
data = continuous_data_generation_arbitrary_lags(nSamples_list[-1],B=B1,A=A1, noise_pow=noise_pow1,nlags=nLags)
print("Time taken to generate data = ", time.time()-start_data)

# np.savetxt("dataSet_20Nodes_phase_violation.txt", data)

# Plot each node's time series in its own figure (one plt.show() per node).
# Loop over the actual number of columns instead of 20 copy-pasted calls.
for node_col in range(data.shape[1]):
    plt.plot(np.arange(0, data.shape[0]), data[:, node_col])
    plt.show()

# %% Set the nSample to be maximum possible samples depending on the size of the dataset

# Number of complete n_FFT-length windows available in the generated data.
nSamples = data.shape[0]//n_FFT

############################################################################################################################################################
# PC (frequency-domain Wiener projection) test for each chosen threshold
############################################################################################################################################################
freq_choice = range(n_FFT)
thresh_freq = np.array([0.04])
error_freq_PC = np.zeros(len(thresh_freq))
TPR = np.zeros(len(thresh_freq))
FPR = np.zeros(len(thresh_freq))
computation_time = np.zeros(len(thresh_freq))

# MG, Skeleton and true_ess_graph are computed from Adj in the network-generation
# section. They depend only on Adj, so they are not recomputed here.

ind_thresh = 0
start_time = time.time()
C_freq = Run_PC_Wiener_test_FFT_avg(data, n_FFT, thresh_freq[ind_thresh],Q=5)
# A pair with no separating set (value None) is kept as a skeleton edge.
Top_est_half = np.zeros([num_nodes,num_nodes])
for pair, sep_set in C_freq.items():
    if sep_set is None:
        Top_est_half[pair] = 1
Top_est = Top_est_half + Top_est_half.T
est_essential_graph_freq,est_colliders_freq = estimate_essential_graph_new(Top_est,C_freq)
end_time = time.time()

# Number of mismatched entries, normalised by the sum of true_ess_graph's entries.
error_freq_PC[ind_thresh] = np.sum(est_essential_graph_freq != true_ess_graph)/np.sum(true_ess_graph)

TP, FP, TN, FN = get_confusion_matrix(true_ess_graph, est_essential_graph_freq)
# Guard against an empty positive/negative class instead of ZeroDivisionError.
TPR[ind_thresh] = TP/(TP+FN) if (TP+FN) else np.nan
FPR[ind_thresh] = FP/(TN+FP) if (TN+FP) else np.nan
computation_time[ind_thresh] = end_time - start_time

print("----------------------------------------______")
print('Threshold=', thresh_freq[ind_thresh])
print(f"TPR = {TPR[ind_thresh]}, FPR = {FPR[ind_thresh]}, CS = {TPR[ind_thresh]-FPR[ind_thresh]}")

CS = TPR - FPR



