"""
This script generates data for a 20-node network whose dynamics are nonlinear,
runs the Wiener phase algorithm over a range of thresholds at a fixed sample
count, and reports the accuracy and computation time for each threshold.
"""


import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
import time
import random





def generate_graph(num_nodes,max_in_degree):
    """Build a random DAG in which every edge points from a lower to a higher node index.

    For each node an in-degree is drawn uniformly from 1..max_in_degree and
    capped at the number of lower-indexed nodes; that many parents are then
    sampled without replacement from the lower-indexed nodes.

    Prints the node and edge counts.

    Returns:
        (G, Adj): the networkx DiGraph and the transposed dense adjacency
        matrix, so that Adj[child, parent] == 1 for every edge parent -> child.
    """
    node_ids = list(range(num_nodes))
    dag = nx.DiGraph()
    dag.add_nodes_from(node_ids)
    for target in node_ids:
        # Only `target` earlier nodes exist, so the drawn in-degree is capped there.
        n_parents = min(random.randint(1, max_in_degree), target)
        parents = random.sample(node_ids[:target], n_parents)
        dag.add_edges_from([(parent, target) for parent in parents])
    Adj = nx.adjacency_matrix(dag).toarray().T
    print("Number of nodes:", dag.number_of_nodes())
    print("Number of edges:", dag.number_of_edges())
    return dag, Adj

    

def create_A_B_coeffs_6_nodes(Adj,filter=None):
    """Create the coefficient arrays for a 6-node network.

    A (6 x 3) holds fixed coefficients for each node's interaction with its
    own past. B (nNodes x nNodes x 3) holds the coupling between nodes: it is
    drawn uniformly from [0.2, 0.4), masked by Adj, and only lag index 0 is
    kept (the other two lag slices are zeroed).

    `filter` is accepted but not used.

    Returns:
        (A, B)
    """
    n_nodes = Adj.shape[0]

    A = np.array([
        [1.739381261594696859e-02, 1.043628756956818150e-02, 5.218143784784090751e-03],
        [3.415208172002098808e-01, 1.366083268800839523e-01, 6.830416344004197615e-02],
        [8.289102927104288199e-01, 5.802372048973001295e-01, 3.315641170841715502e-01],
        [8.846579905691620560e-02, 4.423289952845810280e-02, 2.653973971707486099e-02],
        [8.562817629518102436e-01, 5.137690577710861684e-01, 2.568845288855430842e-01],
        [1.630296867565628194e-01, 1.141207807295939597e-01, 4.890890602696884581e-02],
    ])

    # Broadcasting the mask over the lag axis matches repeating Adj three times.
    coupling = np.random.uniform(0.2, 0.4, size=(n_nodes, n_nodes, 3))
    B = coupling * Adj[:, :, np.newaxis]
    B[:, :, 1:3] = 0
    return A, B



def create_A_B_coeffs_20_nodes(Adj,filter=None):
    """Create the coefficient arrays for the 20-node network.

    A (20 x 3) holds fixed coefficients for each node's interaction with its
    own past; its rows are a 6-row pattern repeated cyclically. B
    (nNodes x nNodes x 3) holds the coupling between nodes, drawn uniformly
    from [0.2, 0.4) and masked by Adj on every lag slice.

    `filter` is accepted but not used.

    Returns:
        (A, B)
    """
    n_nodes = Adj.shape[0]

    pattern = np.array([
        [1.739381261594696859e-02, 1.043628756956818150e-02, 5.218143784784090751e-03],
        [3.415208172002098808e-01, 1.366083268800839523e-01, 6.830416344004197615e-02],
        [8.289102927104288199e-01, 5.802372048973001295e-01, 3.315641170841715502e-01],
        [8.846579905691620560e-02, 4.423289952845810280e-02, 2.653973971707486099e-02],
        [8.562817629518102436e-01, 5.137690577710861684e-01, 2.568845288855430842e-01],
        [1.630296867565628194e-01, 1.141207807295939597e-01, 4.890890602696884581e-02],
    ])
    # Four copies give 24 rows; the first 20 are the fixed table for this network.
    A = np.tile(pattern, (4, 1))[:20]

    coupling = np.random.uniform(0.2, 0.4, size=(n_nodes, n_nodes, 3))
    B = coupling * Adj[:, :, np.newaxis]

    return A, B



def continuous_data_generation_arbitrary_lags(nSamples,B,A=None,noise_pow=-1,nlags=0):
    """Generate time-series data from a VAR model with an added quadratic coupling term.

    Each sample is
        x[t] = e[t] - A[:,0]*x[t-1] - A[:,1]*x[t-2] - A[:,2]*x[t-3]
               + B[:,:,0] @ x[t-nlags] + 0.1 * B[:,:,0] @ (x[t-nlags] ** 2)
    where e[t] is i.i.d. Gaussian noise with per-node power `noise_pow`.
    Only lag slice 0 of B is used.

    Args:
        nSamples: total number of samples to generate.
        B: (p x p x lags) coupling coefficients across variables.
        A: (p x 3) self-dynamics coefficients; None means no self dynamics.
        noise_pow: noise power, scalar or length-p array; -1 (the default)
            selects unit power for every node.
        nlags: lag applied to the other variables in the coupling terms. The
            first `nlags` samples are pure noise.

    Returns:
        (nSamples x p) array of generated samples.
    """
    nNodes = B.shape[0]
    if A is None:
        # No self-dynamics supplied: treat the autoregressive part as zero.
        A = np.zeros((nNodes, 3))
    # The sentinel check must work for the scalar default as well as for arrays.
    if np.all(np.asarray(noise_pow) == -1):
        noise_pow = np.ones(nNodes)
    noise_std = np.sqrt(noise_pow)
    coupling = B[:, :, 0]
    print("Generating continous data.. nSamples: ",nSamples," for nNodes :",nNodes)
    x = np.zeros((nSamples,nNodes))
    print(f"B={B}")
    for ind in range(nlags):
        x[ind,:] = np.random.randn(nNodes)*noise_std
    for ind_Samp in range(nlags,nSamples):
        # Negative indices for the first few samples wrap to rows that are still zero.
        x[ind_Samp,:] = (np.random.randn(nNodes)*noise_std
                         - A[:,0]*x[ind_Samp-1,:]
                         - A[:,1]*x[ind_Samp-2,:]
                         - A[:,2]*x[ind_Samp-3,:])
        # Use the function's own lag argument, not a module-level global.
        x_lagged = np.copy(x[ind_Samp-nlags,:])
        x_non_linear = np.square(x_lagged)
        x[ind_Samp,:] = x[ind_Samp,:] + np.dot(coupling,x_lagged) + 0.1*np.dot(coupling,x_non_linear)
    return x



def find_skeleton_from_ess(Adj):
    """Return the undirected skeleton of a graph as a boolean matrix.

    Entry [a, b] (a != b) is True when Adj[a, b] == 1 or Adj[b, a] == 1.
    The diagonal is always False.
    """
    has_edge = Adj == 1
    skeleton = has_edge | has_edge.T
    # Self-loops are never part of the skeleton.
    np.fill_diagonal(skeleton, False)
    return skeleton






def compute_undir_graph_degrees(Adj):
    """Return (min, max, average) of the column sums of an adjacency matrix.

    For a symmetric 0/1 matrix the column sums are the node degrees.
    """
    n_nodes = Adj.shape[0]
    column_totals = [sum(Adj[:, col]) for col in range(n_nodes)]
    largest = max(column_totals)
    smallest = min(column_totals)
    mean_degree = sum(column_totals) / n_nodes
    return smallest, largest, mean_degree


def plot_custom_graph(adj_matrix):
    """Draw the directed graph described by `adj_matrix` in a matplotlib window.

    Nodes are labelled 'V1', 'V2', ... in index order. The layout comes from
    networkx's spring layout; no seed is passed, so the picture can differ
    between calls.
    """
    digraph = nx.from_numpy_array(adj_matrix, create_using=nx.DiGraph)

    # Replace the integer node ids with 1-based 'V<k>' labels.
    labels = {}
    for idx in range(len(adj_matrix)):
        labels[idx] = f'V{idx+1}'
    digraph = nx.relabel_nodes(digraph, labels)

    layout = nx.spring_layout(digraph)

    plt.figure(figsize=(10, 8))
    nx.draw(
        digraph,
        layout,
        with_labels=True,
        node_color='skyblue',
        node_size=800,
        arrowsize=20,
        font_size=10,
    )
    # Every edge gets an empty label.
    blank_labels = {(src, dst): '' for src, dst in digraph.edges()}
    nx.draw_networkx_edge_labels(digraph, layout, edge_labels=blank_labels, font_color='gray')
    plt.title("Graph from Custom Adjacency Matrix")
    plt.axis('off')
    plt.tight_layout()
    plt.show()




def compute_fft(data,nfft):
    """Split time-series data into consecutive segments and take the FFT of each.

    Args:
        data: (nSamples x nNodes) time-series data.
        nfft: number of FFT points, i.e. the length of each segment.

    Returns:
        Complex array of shape (nSamples // nfft, nNodes, nfft); entry
        [k, :, :] is the transposed FFT of rows k*nfft .. (k+1)*nfft - 1.
        Samples left over after the last complete segment are discarded.
    """
    nSamples, nNodes = data.shape
    # Integer division already drops the residual samples, so every one of the
    # nTrajectories segments is complete and must be transformed.
    nTrajectories = nSamples // nfft
    y = np.zeros((nTrajectories, nNodes, nfft), dtype=complex)
    for ind in range(nTrajectories):
        segment = data[ind * nfft:(ind + 1) * nfft, :]
        # The FFT of a segment is (nfft, nNodes); store it as (nNodes, nfft).
        y[ind, :, :] = np.fft.fft(segment, axis=0).T
    return y



def Least_Squares_Solution(X,Z):
    """Return the least-squares coefficients of Z regressed on the columns of X.

    Finds a such that Z ~= a0*x0 + a1*x1 + ... + am*xm, where the x's are the
    columns of X, by solving the normal equations (X^H X) a = X^H Z.

    Args:
        X: (n x m) regressor matrix, or a length-n vector for one regressor.
        Z: length-n target vector (or an n x k matrix of targets).

    Returns:
        The coefficient vector (a scalar when X is a vector).

    Raises:
        numpy.linalg.LinAlgError: if X^H X is singular.
    """
    X_star = X.conj().T
    XX = np.matmul(X_star, X)
    if np.isscalar(XX):
        # Single regressor vector: the Gram "matrix" is a scalar.
        return np.matmul(np.divide(X_star, XX), Z)
    # Solve the normal equations directly instead of forming an explicit inverse.
    return np.linalg.solve(XX, np.matmul(X_star, Z))



def Wiener_Proj(X,Z):
    """Return the coefficients of the projection of Z onto the columns of X.

    Finds a such that Z ~= a0*x0 + a1*x1 + ... + am*xm, where m is the number
    of columns of X, by solving the normal equations (X^H X) a = X^H Z.

    Args:
        X: (n x m) regressor matrix, or a length-n vector for one regressor.
        Z: length-n target vector (or an n x k matrix of targets).

    Returns:
        The coefficient vector (a scalar when X is a vector).

    Raises:
        numpy.linalg.LinAlgError: if X^H X is singular.
    """
    X_star = X.conj().T
    XX = np.matmul(X_star, X)
    if np.isscalar(XX):
        # Single regressor vector: the Gram "matrix" is a scalar.
        return np.matmul(np.divide(X_star, XX), Z)
    # Solve the normal equations directly instead of forming an explicit inverse.
    return np.linalg.solve(XX, np.matmul(X_star, Z))



def compute_Wiener_coeffs(nNodes,data):
    """Estimate, per frequency, the Wiener coefficients of every node on all the others.

    Args:
        nNodes: number of nodes.
        data: array indexed as [trajectory, node, frequency].

    Returns:
        (W, W_mag) where W has shape (nNodes, nFreq, nNodes) with
        W[p, f, q] the coefficient of node q when node p is projected onto
        all other nodes at frequency f (W[p, f, p] is 0), and W_mag[p, q] is
        the maximum of |W[p, f, q]| over the frequencies.
    """
    n_freq = data.shape[2]
    W = np.zeros((nNodes, n_freq, nNodes), dtype=complex)
    W_mag = np.zeros((nNodes, nNodes))
    for target in range(nNodes):
        per_freq = np.zeros((n_freq, nNodes), dtype=complex)
        for k in range(n_freq):
            snapshot = data[:, :, k]
            others = np.delete(snapshot, target, 1)
            coeffs = Wiener_Proj(others, snapshot[:, target])
            # Put a zero back at the target's own position to keep node indexing.
            per_freq[k, :] = np.insert(coeffs, target, 0)
        W[target, :, :] = per_freq
        W_mag[target, :] = np.amax(abs(per_freq), axis=0)
    return W, W_mag



def compute_partial_Wiener_coeffs_FFT(data,i,j,z=[]):
    """Compute the partial Wiener coefficients of node i on node j and a conditioning set.

    Args:
        data: FFT data indexed as [trajectory, node, frequency].
        i: first node (the one being regressed).
        j: second node.
        z: conditioning set of nodes.

    Returns:
        (W, W_mag): W[f, :] are the estimated coefficients at frequency f for
        the regressors [j, *z] (duplicates removed, first occurrence kept, so
        column 0 belongs to j); W_mag is the maximum of |W| over frequencies.
    """
    candidates = np.append(j, np.int32(z))
    # Drop repeated node indices while keeping the original ordering.
    _, first_seen = np.unique(candidates, return_index=True)
    regressors = [candidates[pos] for pos in sorted(first_seen)]
    n_freq = data.shape[2]
    W = np.zeros((n_freq, len(regressors)), dtype=complex)
    for k in range(n_freq):
        snapshot = data[:, :, k]
        W[k, :] = Least_Squares_Solution(snapshot[:, regressors], snapshot[:, i])
    W_mag = np.amax(abs(W), axis=0)
    return W, W_mag



def compute_partial_Wiener_coeffs_avg(data,i,j,z=[],freq_choice=[]):
    """Compute partial Wiener coefficients of node i on [j, *z], averaged over chosen frequencies.

    Args:
        data: FFT data indexed as [trajectory, node, frequency].
        i: node being regressed.
        j: second node.
        z: conditioning set of nodes.
        freq_choice: frequency indices to evaluate.

    Returns:
        (W, W_mag_avg): W has shape (n_FFT, n_regressors) and is filled only
        at the rows listed in freq_choice; W_mag_avg is the sum of |W| over
        those rows divided by len(freq_choice). Column 0 belongs to j.
    """
    candidates = np.append(j, np.int32(z))
    # Drop repeated node indices while keeping the original ordering.
    _, first_seen = np.unique(candidates, return_index=True)
    regressors = [candidates[pos] for pos in sorted(first_seen)]
    n_regressors = len(regressors)
    W = np.zeros((data.shape[2], n_regressors), dtype=complex)
    magnitude_total = np.zeros(n_regressors, dtype=float)
    for k in freq_choice:
        snapshot = data[:, :, k]
        W[k, :] = Wiener_Proj(snapshot[:, regressors], snapshot[:, i])
        magnitude_total = magnitude_total + np.abs(W[k, :])
    W_mag_avg = np.divide(magnitude_total, len(freq_choice))
    return W, W_mag_avg



def Run_Wiener_Phase_avg(data,nFFT,thresh=0.1):
    """Estimate a partially directed graph from time-series data.

    Steps, all using the same threshold `thresh`:
      1. FFT the data in segments of length nFFT and compute the Wiener
         coefficients of every node on all the others at every frequency.
      2. Moral graph: frequency-averaged |W|, symmetrized, thresholded.
      3. Skeleton: frequency-averaged |imag(W)|, symmetrized, thresholded.
      4. Pairs that are in exactly one of the two graphs are treated as
         spurious; their common skeleton neighbours are collider candidates.
      5. Edges into accepted colliders are oriented, then remaining
         bidirectional edges touching a collider are oriented.

    Args:
        data: time-series samples, one column per node (nNodes = data.shape[1]).
        nFFT: FFT segment length.
        thresh: threshold applied to the averaged magnitudes.

    Returns:
        Boolean (nNodes x nNodes) matrix of the estimated graph; an edge that
        is True in both [a, b] and [b, a] is left undirected.
    """
    nNodes=data.shape[1]
    data_FFT=compute_fft(data,nfft=nFFT)
    # Wmag (the per-pair maximum over frequencies) is not used below.
    W,Wmag=compute_Wiener_coeffs(nNodes, data_FFT)
    
    
    
    # Moral Graph estimation using absolute value of average magnitude of Wiener projection over frequencies
    W_mag_avg = np.zeros((nNodes,nNodes),dtype=float)

    freq_choice = range(nFFT)
    for i in freq_choice:
        W_mag_avg = W_mag_avg + np.abs(W[:,i,:])
        
    W_mag_avg = np.divide(W_mag_avg,len(freq_choice))
    # Symmetrize by averaging the two directions before thresholding.
    W_mag_avg = np.divide(W_mag_avg+np.transpose(W_mag_avg),2)
    moral_graph_adj=W_mag_avg>thresh
    
    # Skeleton estimation using the frequency-averaged magnitude of the
    # imaginary part of the Wiener projection coefficients
    W_imag=W.imag
    W_imag_avg = np.zeros((nNodes,nNodes),dtype=float)
    freq_choice = range(nFFT)
    for i in freq_choice:
        W_imag_avg = W_imag_avg + np.abs(W_imag[:,i,:])
        
    W_imag_avg = np.divide(W_imag_avg,len(freq_choice))
    W_imag_avg = np.divide(W_imag_avg+np.transpose(W_imag_avg),2)
    skeleton_adj=W_imag_avg>thresh
    
    # Essential graph estimation
    # Pairs present in exactly one of the moral graph and the skeleton (XOR).
    spurious_adj=abs(moral_graph_adj^skeleton_adj)
    # C[i, j]: common skeleton neighbours of the spurious pair (i, j).
    C={}
    collider_graph=np.zeros_like(skeleton_adj)
    collider_graph[:]=skeleton_adj[:]
    collider_set = []
    for i in range(nNodes):
        for j in range(nNodes):
            # i > j visits each unordered pair once.
            if i>j and spurious_adj[i,j]==1:
                potential_colider=[]
                for k in range(nNodes):
                    if skeleton_adj[i,k]==1 and skeleton_adj[j,k]==1:
                        potential_colider.append(k)
                C[i,j]=potential_colider
                if len(C[i,j])==1:
                    # A single common neighbour is accepted as the collider
                    # without a further test; clearing [i, k] and [j, k]
                    # leaves only the [k, i] and [k, j] entries set.
                    collider_graph[i,C[i,j]]=False
                    collider_graph[j,C[i,j]]=False
                    collider_set.append(C[i,j][0])
                else:
                    for l in C[i,j]:
                        # Regress i on [j, l]; entry 0 is the averaged
                        # magnitude of the coefficient on j given l.
                        W_ij_c,W_ij_c_mag=compute_partial_Wiener_coeffs_avg(data_FFT, i, j,l,freq_choice = range(nFFT))
                        if W_ij_c_mag[0]>thresh:
                            collider_graph[i,l]=False
                            collider_graph[j,l]=False
                            collider_set.append(l)
                            
                            
    # Orient remaining bidirectional edges that touch an accepted collider:
    # the entry whose row index is the collider is cleared.
    final_graph = np.copy(collider_graph)
    for i in range(nNodes):
        for j in range(nNodes):
            if i>j and collider_graph[i,j]==1 and collider_graph[j,i]==1:
                if i in collider_set:
                    final_graph[i,j] = 0
                if j in collider_set:
                    final_graph[j,i] = 0
                        
    
    return final_graph




def data_processing(original_data_raw,nSamples,n_FFT,nNodes):
    """FFT consecutive, non-overlapping windows of the raw time series.

    Args:
        original_data_raw: raw data with one row per time sample and one
            column per node.
        nSamples: number of windows (trajectories) to transform.
        n_FFT: window length, which is also the number of FFT points.
        nNodes: number of nodes (columns).

    Returns:
        Complex array of shape (nSamples, nNodes, n_FFT); entry [k] is the
        transposed FFT of rows k*n_FFT .. (k+1)*n_FFT - 1.
    """
    spectra = np.zeros((nSamples, nNodes, n_FFT), dtype=complex)
    for traj in range(nSamples):
        first_row = traj * n_FFT
        window = original_data_raw[first_row:first_row + n_FFT, :]
        spectra[traj, :, :] = np.fft.fft(window, axis=0).T
    return spectra


def get_confusion_matrix(true_graph,estimated_graph):
    """Count edge-level agreement between two adjacency matrices.

    Only off-diagonal entries are considered, and an entry contributes only
    when it equals 0 or 1 in both matrices.

    Returns:
        (TP, FP, TN, FN) as Python ints.
    """
    n = true_graph.shape[0]
    truth = np.asarray(true_graph)[:n, :n]
    guess = np.asarray(estimated_graph)[:n, :n]
    off_diagonal = ~np.eye(n, dtype=bool)

    truth_on, truth_off = truth == 1, truth == 0
    guess_on, guess_off = guess == 1, guess == 0

    TP = int(np.count_nonzero(truth_on & guess_on & off_diagonal))
    FN = int(np.count_nonzero(truth_on & guess_off & off_diagonal))
    FP = int(np.count_nonzero(truth_off & guess_on & off_diagonal))
    TN = int(np.count_nonzero(truth_off & guess_off & off_diagonal))

    return TP, FP, TN, FN


def find_moral_graph_and_skeleton_from_adjacency_matrix(Adj):
    """Find the moral graph and skeleton of a directed graph from its adjacency matrix.

    Args:
        Adj: adjacency matrix of the directed graph; Adj[c, p] == 1 is read
            as "p is a parent of c".

    Returns:
        (MG, Skeleton): Skeleton is Adj + Adj.T; MG is the skeleton with an
        extra edge (both directions set to 1) between every pair of distinct
        nodes that share a child.
    """
    Skeleton = Adj + Adj.T
    MG = np.copy(Skeleton)
    for child in range(Adj.shape[0]):
        parents = np.flatnonzero(Adj[child] == 1)
        # Marry every pair of distinct parents of this child.
        for first in parents:
            for second in parents:
                if first != second:
                    MG[first, second] = 1
                    MG[second, first] = 1
    return MG, Skeleton

def find_true_Essential_graph(Skeleton, Adj):
    """Build the true essential graph from the skeleton and the directed adjacency matrix.

    A node is a collider when it has two distinct parents that are not
    adjacent in the skeleton. Edges from those parents into the collider are
    oriented; any remaining undirected edge that touches a collider is then
    oriented as well.

    Args:
        Skeleton: undirected skeleton matrix (entries equal to 1 mark edges).
        Adj: directed adjacency matrix; Adj[c, p] == 1 means p is a parent of c.

    Returns:
        A copy of Skeleton in which the entry [a, b] has been zeroed for
        every edge oriented so that only [b, a] remains.
    """
    # Take the size from the argument rather than from a module-level global.
    n_nodes = Adj.shape[0]
    collider_graph = np.copy(Skeleton)
    collider_nodes = set()
    for child in range(n_nodes):
        parents = [p for p in range(n_nodes) if p != child and Adj[child, p] == 1]
        for first in parents:
            for second in parents:
                # Unshielded collider: two distinct, non-adjacent parents.
                if first != second and Skeleton[first, second] != 1:
                    collider_graph[first, child] = 0
                    collider_graph[second, child] = 0
                    collider_nodes.add(child)
    # Orient the remaining undirected edges that touch a collider.
    true_ess_graph = np.copy(collider_graph)
    for i in range(n_nodes):
        for j in range(i):
            if collider_graph[i, j] == 1 and collider_graph[j, i] == 1:
                if i in collider_nodes:
                    true_ess_graph[i, j] = 0
                if j in collider_nodes:
                    true_ess_graph[j, i] = 0

    return true_ess_graph

# Main script ####################################################################################################################
##################################################################################################################################
# Everything below runs at import time: configuration, data generation,
# threshold sweep and plotting.
start_sim_time=time.time()
print("Start time: ",start_sim_time)
nSamples_list = [128000]  # sample counts; the data generation below uses the last entry
num_nodes = 20
# NOTE(review): max_in_degree, max_out_degree, avg_deg_lim, max_deg_lim and
# min_deg_lim are not referenced in the visible part of this script.
max_in_degree = 2
max_out_degree = 1
noise_pow1 = 1*np.ones(num_nodes)  # unit noise power for every node
avg_deg_lim = [2.1,3.5]
max_deg_lim = 9
min_deg_lim = 1
# Seed both generators before any random draw so runs are repeatable.
random.seed(3)
np.random.seed(3)
nLags = 3
nNodes = num_nodes
num_nodes = nNodes
nFFT = 32  # FFT segment length passed to the algorithm



########### ########### ########### ########### ########### ########### ###########
# Network and Data Generation
########### ########### ########### ########### ########### ########### ###########
# Hand-specified 20 x 20 adjacency matrix. Rows index the child and columns
# the parent (the convention used by the collider search in
# find_true_Essential_graph).
Adj = np.array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0]])



# Ground-truth structures derived from Adj, and the model coefficients.
# B1 is drawn from np.random, so it depends on the seed set earlier.
MG,Skeleton = find_moral_graph_and_skeleton_from_adjacency_matrix(Adj)
A1,B1 = create_A_B_coeffs_20_nodes(Adj)

# Reference graph that the estimates are scored against in the sweep below.
true_ess_graph = find_true_Essential_graph(Skeleton, Adj)
min_degree, max_degree, avg_degree = compute_undir_graph_degrees(Skeleton)
print(f"Max Degree = {max_degree}")
print(f"Min Degree = {min_degree}")
print(f"Avg Degree = {avg_degree}")
# plot_directed_graph(true_ess_graph)
########### ########### Generate data ############################################
start_data=time.time()
print("Data generation started at ", start_data)
# Generate the largest requested sample count in one go.
data = continuous_data_generation_arbitrary_lags(nSamples_list[-1],B=B1,A=A1, noise_pow=noise_pow1,nlags=nLags)
print("Time taken to generate data = ", time.time()-start_data)

# np.savetxt("dataSet_20Nodes_phase_violation.txt", data)
        
# Show the time series of each of the 20 nodes in its own figure, in node order.
for node_idx in range(20):
    plt.plot(np.arange(0,data.shape[0]), data[:,node_idx])
    plt.show()

            
            
###########################################################################################################################
# Algorithm 

# Sweep the detection threshold, scoring the estimated graph against the true
# essential graph and timing each run.
thresh_freq = np.arange(0.01,0.05,0.001)#np.array([0.04])#
n_thresh = len(thresh_freq)
error_freq_PC = np.zeros(n_thresh)
TPR = np.zeros(n_thresh)
FPR = np.zeros(n_thresh)
computation_time = np.zeros(n_thresh)

for ind_thresh, thresh_val in enumerate(thresh_freq):
    start_time = time.time()
    est_ess_graph = Run_Wiener_Phase_avg(data,nFFT,thresh=thresh_val)
    end_time = time.time()
    # Mismatched entries relative to the number of true edges.
    n_mismatch = sum(sum(est_ess_graph!=true_ess_graph))
    error_freq_PC[ind_thresh] = n_mismatch/sum(sum(true_ess_graph))
    computation_time[ind_thresh] = end_time - start_time
    TP, FP, TN, FN = get_confusion_matrix(true_ess_graph, est_ess_graph)
    TPR[ind_thresh] = TP/(TP+FN)
    FPR[ind_thresh] = FP/(TN+FP)
    print("----------------------------------------______")
    print('Threshold=', thresh_val)
    print(f"TPR = {TPR[ind_thresh]}, FPR = {FPR[ind_thresh]}, CS = {TPR[ind_thresh]-FPR[ind_thresh]}")
    print(f"Run time = {computation_time[ind_thresh]}")

CS = TPR - FPR
# One figure per metric, each plotted against the threshold.
result_curves = [
    (error_freq_PC, 'g-.', "error rate (time)"),
    (TPR, 'r-.', "TPR"),
    (FPR, 'b-.', "FPR"),
    (CS, 'y-.', "CS"),
    (computation_time, 'm-.', "Computation time"),
]
for curve, line_fmt, curve_label in result_curves:
    plt.plot(thresh_freq,curve,line_fmt,label=curve_label,linewidth=3)
    plt.legend()
    plt.show()
    
    
  