"""
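This script runs the Wiener-phase algorithm on the river-runoff data, compares the estimated essential graph with the known one, and reports the true/false positive rates (TPR/FPR) and the algorithm run time.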
"""


import time
from pathlib import Path

import numpy as np
import networkx as nx
import matplotlib.pyplot as plt


# %% Graph plotting functions

def plot_topology(adjacency_matrix):
    """Draw the undirected graph described by a square adjacency matrix.

    Node ``V{k}`` stands for row/column ``k - 1``. An edge ``V(a+1) - V(b+1)``
    is added whenever ``adjacency_matrix[a, b]`` is truthy, so an entry set in
    only one direction still produces the edge. The figure is shown without
    blocking the caller.

    Node positions are hand-placed for ``V1``..``V12``. If the graph contains a
    node outside that layout (more than 12 nodes), a seeded spring layout is
    used instead, because drawing a node that has no position fails.

    Parameters
    ----------
    adjacency_matrix : ndarray, shape (nNodes, nNodes)
        Truthy entries mark edges.
    """
    nNodes = adjacency_matrix.shape[0]
    G = nx.Graph()
    G.add_nodes_from(['V' + str(ind) for ind in range(1, nNodes + 1)])
    for ind1 in range(nNodes):
        for ind2 in range(nNodes):
            if adjacency_matrix[ind1, ind2]:
                G.add_edge('V' + str(ind1 + 1), 'V' + str(ind2 + 1))

    # Hand-tuned positions for the 12-node graph used in this script.
    pos = {'V1': (-2, 0), 'V2': (0, 2), 'V3': (1, -0.5), 'V4': (2, 1),
           'V5': (3, -1), 'V6': (5, 0), 'V7': (5, 2.25), 'V8': (5, 4.5),
           'V9': (7, -3), 'V10': (9, 0.75), 'V11': (7, -0.5), 'V12': (4, -4)}
    if not all(node in pos for node in G):
        # A fixed seed keeps the fallback layout reproducible between calls.
        pos = nx.spring_layout(G, seed=0)

    options = {
        "font_size": 12,
        "node_size": 1000,
        "node_color": "white",
        "edgecolors": "black",
        "linewidths": 2,
        "width": 2,
    }
    nx.draw_networkx(G, pos, **options)
    # Set margins for the axes so that nodes aren't clipped
    ax = plt.gca()
    ax.margins(0.20)
    plt.axis("off")
    plt.show(block=False)
    plt.pause(0.001)
    
    
    
    
    
    
    
def plot_directed_graph(adjacency_matrix):
    """Draw the directed graph described by a square adjacency matrix.

    Node ``V{k}`` stands for row/column ``k - 1``. A truthy
    ``adjacency_matrix[a, b]`` is drawn as the arrow ``V(b+1) -> V(a+1)``. Rows
    are edge targets and columns are edge sources. If both ``[a, b]`` and
    ``[b, a]`` are set, arrows are drawn in both directions. The figure is
    shown without blocking the caller.

    Node positions are hand-placed for ``V1``..``V12``. If the graph contains a
    node outside that layout (more than 12 nodes), a seeded spring layout is
    used instead, because drawing a node that has no position fails.

    Parameters
    ----------
    adjacency_matrix : ndarray, shape (nNodes, nNodes)
        Truthy entries mark edges (column -> row).
    """
    nNodes = adjacency_matrix.shape[0]
    G = nx.DiGraph()
    G.add_nodes_from(['V' + str(ind) for ind in range(1, nNodes + 1)])
    for ind1 in range(nNodes):
        for ind2 in range(nNodes):
            if adjacency_matrix[ind1, ind2]:
                # Column index is the source, row index the target.
                G.add_edge('V' + str(ind2 + 1), 'V' + str(ind1 + 1))

    # Hand-tuned positions for the 12-node graph used in this script.
    pos = {'V1': (-2, 0), 'V2': (0, 2), 'V3': (1, -1), 'V4': (2, 1),
           'V5': (3, -1), 'V6': (5, 0), 'V7': (5, 2.25), 'V8': (5, 4.5),
           'V9': (7, -3), 'V10': (9, 0.75), 'V11': (7, -0.5), 'V12': (4, -4)}
    if not all(node in pos for node in G):
        # A fixed seed keeps the fallback layout reproducible between calls.
        pos = nx.spring_layout(G, seed=0)

    options = {
        "font_size": 12,
        "node_size": 1000,
        "node_color": "white",
        "edgecolors": "black",
        "linewidths": 2,
        "width": 2,
    }
    nx.draw_networkx(G, pos, **options)
    # Set margins for the axes so that nodes aren't clipped
    ax = plt.gca()
    ax.margins(0.20)
    plt.axis("off")
    plt.show(block=False)
    plt.pause(0.001)




def compute_fft(data,nfft):
    """Split a multichannel time series into consecutive windows and FFT each one.

    Parameters
    ----------
    data : ndarray, shape (nSamples, nNodes)
        Time-series samples, one column per node.
    nfft : int
        Window length, which is also the number of FFT points.

    Returns
    -------
    y : complex ndarray, shape (nSamples // nfft, nNodes, nfft)
        ``y[k, :, :]`` is the (unnormalised) FFT along time of rows
        ``k*nfft`` .. ``(k+1)*nfft - 1``, laid out as (node, frequency).
        Trailing samples that do not fill a whole window are discarded.
    """
    nSamples, nNodes = data.shape
    # Integer division already drops the residual samples that can't fill a window.
    nTrajectories = nSamples // nfft
    y = np.zeros((nTrajectories, nNodes, nfft), dtype=complex)
    # Transform every full window, the last one included.
    for ind in range(nTrajectories):
        x = data[ind * nfft:(ind + 1) * nfft, :]
        X = np.fft.fft(x, axis=0)  # X is (nfft, nNodes)
        y[ind, :, :] = np.transpose(X)
    return y



def Least_Squares_Solution(X,Z):
    """Return the least-squares coefficients ``a`` minimising ``||Z - X a||``.

    Solves the normal equations ``(X^H X) a = X^H Z``, where ``X^H`` is the
    conjugate transpose. This projects ``Z`` onto the span of the columns of
    ``X``.

    Parameters
    ----------
    X : ndarray, shape (n, m) or (n,)
        Regressors, one per column. A 1-D ``X`` is treated as a single regressor.
    Z : ndarray, shape (n,) or (n, k)
        Target vector(s).

    Returns
    -------
    ndarray
        Shape (m,) or (m, k). A scalar, or shape (k,), when ``X`` is 1-D.

    Raises
    ------
    numpy.linalg.LinAlgError
        If ``X`` is 2-D and ``X^H X`` is singular.
    """
    X_star = X.conj().T
    XX = np.matmul(X_star, X)
    XZ = np.matmul(X_star, Z)
    if np.isscalar(XX):
        # 1-D X: the Gram "matrix" is a single number.
        return XZ / XX
    # Solving directly reuses the Gram matrix and avoids forming an explicit
    # inverse, which is both slower and less accurate.
    return np.linalg.solve(XX, XZ)



def Wiener_Proj(X,Z):
    """Return the Wiener (least-squares) projection coefficients of ``Z`` onto ``X``.

    Finds ``a`` minimising ``||Z - X a||`` by solving ``(X^H X) a = X^H Z``,
    where ``X^H`` is the conjugate transpose. Column ``c`` of ``X`` gets the
    coefficient ``a[c]``.

    Parameters
    ----------
    X : ndarray, shape (n, m) or (n,)
        Regressors, one per column. A 1-D ``X`` is treated as a single regressor.
    Z : ndarray, shape (n,) or (n, k)
        Target vector(s).

    Returns
    -------
    ndarray
        Shape (m,) or (m, k). A scalar, or shape (k,), when ``X`` is 1-D.

    Raises
    ------
    numpy.linalg.LinAlgError
        If ``X`` is 2-D and ``X^H X`` is singular.
    """
    X_star = X.conj().T
    XX = np.matmul(X_star, X)
    XZ = np.matmul(X_star, Z)
    if np.isscalar(XX):
        # 1-D X: the Gram "matrix" is a single number.
        return XZ / XX
    # Solving directly reuses the Gram matrix and avoids forming an explicit
    # inverse, which is both slower and less accurate.
    return np.linalg.solve(XX, XZ)



def compute_Wiener_coeffs(nNodes,data):
    """Regress every node on all other nodes, separately at each frequency.

    Parameters
    ----------
    nNodes : int
        Number of nodes (columns of each frequency slice of ``data``).
    data : ndarray
        FFT data indexed ``data[trajectory, node, frequency]``.

    Returns
    -------
    W : complex ndarray, shape (nNodes, n_freq, nNodes)
        ``W[p, f, :]`` holds the coefficients obtained by projecting node ``p``
        onto the remaining nodes at frequency ``f``. The self-coefficient
        ``W[p, f, p]`` is 0.
    W_mag : ndarray, shape (nNodes, nNodes)
        Maximum of ``|W[p, :, q]|`` over the sampled frequencies (a sampled
        H-infinity norm).
    """
    n_freq = data.shape[2]
    W = np.zeros([nNodes, n_freq, nNodes], dtype='complex')
    W_mag = np.zeros([nNodes, nNodes])
    for target in range(nNodes):
        for freq in range(n_freq):
            freq_slice = data[:, :, freq]
            regressors = np.delete(freq_slice, target, 1)
            coeffs = Wiener_Proj(regressors, freq_slice[:, target])
            # Re-insert a zero so column q of W always refers to node q.
            W[target, freq, :] = np.insert(coeffs, target, 0)
        W_mag[target, :] = np.amax(np.abs(W[target]), axis=0)
    return W, W_mag



def compute_partial_Wiener_coeffs_FFT(data,i,j,z=None):
    """Compute partial Wiener coefficients of node ``i`` on ``j`` and a conditioning set.

    At every frequency, node ``i`` is regressed by least squares on node ``j``
    together with the nodes in ``z``.

    Parameters
    ----------
    data : ndarray
        FFT data indexed ``data[trajectory, node, frequency]``.
    i : int
        Target node.
    j : int
        First regressor node. Its coefficient is always column 0 of the outputs.
    z : int or sequence of int, optional
        Conditioning nodes. Duplicates, including repeats of ``j``, are dropped
        and first-occurrence order is kept. Defaults to no conditioning nodes.

    Returns
    -------
    W : complex ndarray, shape (nFFT, n_regressors)
        Estimated Wiener filter coefficients at each frequency.
    W_mag : ndarray, shape (n_regressors,)
        Maximum of ``|W|`` taken over the frequencies.
    """
    if z is None:  # avoid a shared mutable default
        z = []
    c = np.append(j, np.int32(z))
    # Order-preserving de-duplication so that j stays in column 0.
    first_idx = np.unique(c, return_index=True)[1]
    c = [c[index] for index in sorted(first_idx)]
    nFFT = data.shape[2]
    W = np.zeros([nFFT, len(c)], dtype='complex')
    for freq_index in range(nFFT):
        fft_data = data[:, :, freq_index]
        W[freq_index, :] = Least_Squares_Solution(fft_data[:, c], fft_data[:, i])
    W_mag = np.amax(abs(W), axis=0)
    return W, W_mag



def compute_partial_Wiener_coeffs_avg(data,i,j,z=None,freq_choice=None):
    """Compute partial Wiener coefficients of ``i`` on ``[j] + z`` and their average magnitude.

    At each selected frequency, node ``i`` is projected onto node ``j`` and the
    conditioning nodes ``z``.

    Parameters
    ----------
    data : ndarray
        FFT data indexed ``data[trajectory, node, frequency]``.
    i : int
        Target node.
    j : int
        First regressor node. Its coefficient is always column 0 of the outputs.
    z : int or sequence of int, optional
        Conditioning nodes. Duplicates, including repeats of ``j``, are dropped
        and first-occurrence order is kept. Defaults to no conditioning nodes.
    freq_choice : sequence of int, optional
        Frequency indices to use. Defaults to all ``data.shape[2]`` frequencies.

    Returns
    -------
    W : complex ndarray, shape (n_FFT, n_regressors)
        Coefficients at each frequency. Rows for frequencies not in
        ``freq_choice`` stay zero.
    W_mag_avg : ndarray, shape (n_regressors,)
        Mean of ``|W|`` over ``freq_choice``.

    Raises
    ------
    ValueError
        If ``freq_choice`` is empty. Averaging over nothing would otherwise
        silently produce NaN.
    """
    if z is None:  # avoid a shared mutable default
        z = []
    c = np.append(j, np.int32(z))
    # Order-preserving de-duplication so that j stays in column 0.
    first_idx = np.unique(c, return_index=True)[1]
    c = [c[index] for index in sorted(first_idx)]
    n_FFT = data.shape[2]
    if freq_choice is None:
        freq_choice = range(n_FFT)
    if len(freq_choice) == 0:
        raise ValueError("freq_choice must contain at least one frequency index")
    W = np.zeros([n_FFT, len(c)], dtype='complex')
    W_mag_avg = np.zeros([len(c)], dtype=float)
    for freq_index in freq_choice:
        fft_data = data[:, :, freq_index]
        W[freq_index, :] = Wiener_Proj(fft_data[:, c], fft_data[:, i])
        W_mag_avg = W_mag_avg + np.abs(W[freq_index, :])
    W_mag_avg = np.divide(W_mag_avg, len(freq_choice))
    return W, W_mag_avg



def Run_Wiener_Phase_avg(data,nFFT,thresh=0.1):
    """Estimate a partially directed graph from multichannel time series.

    Steps:
      1. FFT ``data`` in windows of ``nFFT`` samples (``compute_fft``). For every
         node, compute its per-frequency Wiener projection onto all other nodes
         (``compute_Wiener_coeffs``).
      2. Moral graph: average ``|W|`` over frequencies, symmetrise, and
         threshold.
      3. Skeleton: the same computation with ``|Im W|``.
      4. Pairs flagged by ``moral XOR skeleton`` are treated as spurious. Their
         common skeleton neighbours are tested as colliders, and edges into
         each detected collider are oriented.
      5. Each remaining undirected edge with a collider endpoint is oriented
         away from that collider.

    Parameters
    ----------
    data : ndarray, shape (nSamples, nNodes)
        Time series, one column per node.
    nFFT : int
        FFT window length.
    thresh : float
        Threshold applied to the averaged magnitudes and to the partial
        coefficient used in the collider test.

    Returns
    -------
    final_graph : bool ndarray, shape (nNodes, nNodes)
        ``[a, b]`` and ``[b, a]`` both True denotes an undirected edge. A lone
        True at ``[a, b]`` is an edge ``b -> a``, which is how
        ``plot_directed_graph`` draws it.
    """
    nNodes = data.shape[1]
    data_FFT = compute_fft(data, nfft=nFFT)
    W, _ = compute_Wiener_coeffs(nNodes, data_FFT)
    freq_choice = range(nFFT)

    # Moral graph from the average magnitude of the Wiener coefficients.
    moral_graph_adj = _symmetric_frequency_average(np.abs(W), freq_choice) > thresh
    # Skeleton from the average magnitude of their imaginary part.
    skeleton_adj = _symmetric_frequency_average(np.abs(W.imag), freq_choice) > thresh

    # NOTE(review): XOR also flags pairs present in the skeleton but not in the
    # moral graph. If only moral-minus-skeleton pairs are intended, this should
    # be `moral_graph_adj & ~skeleton_adj`. Confirm against the method.
    spurious_adj = moral_graph_adj ^ skeleton_adj

    collider_graph, collider_set = _mark_colliders(
        data_FFT, skeleton_adj, spurious_adj, nFFT, thresh)
    return _orient_away_from_colliders(collider_graph, collider_set)


def _symmetric_frequency_average(coeff_abs, freq_choice):
    """Average ``coeff_abs[:, f, :]`` over ``f`` in *freq_choice*, then symmetrise.

    *coeff_abs* is indexed (target node, frequency, regressor node), like the
    ``W`` returned by ``compute_Wiener_coeffs``. Returns ``(A + A.T) / 2``,
    where ``A`` is the frequency average.
    """
    avg = np.zeros((coeff_abs.shape[0], coeff_abs.shape[2]), dtype=float)
    for f in freq_choice:
        avg = avg + coeff_abs[:, f, :]
    avg = np.divide(avg, len(freq_choice))
    return np.divide(avg + np.transpose(avg), 2)


def _mark_colliders(data_FFT, skeleton_adj, spurious_adj, nFFT, thresh):
    """Find colliders of spurious pairs and orient the skeleton edges into them.

    Returns ``(collider_graph, collider_set)``. ``collider_graph`` is a copy of
    *skeleton_adj* in which ``[a, k]`` and ``[b, k]`` are cleared for each
    collider ``k`` detected for a spurious pair ``(a, b)``, leaving ``a -> k``
    and ``b -> k``. ``collider_set`` holds the collider node indices.
    """
    nNodes = skeleton_adj.shape[0]
    collider_graph = np.copy(skeleton_adj)
    collider_set = set()
    for i in range(nNodes):
        for j in range(i):  # each unordered pair once (i > j)
            if not spurious_adj[i, j]:
                continue
            # Nodes adjacent, in the skeleton, to both i and j.
            common = [k for k in range(nNodes)
                      if skeleton_adj[i, k] and skeleton_adj[j, k]]
            if len(common) == 1:
                # A single common neighbour must be the collider.
                colliders = common
            else:
                # Keep a candidate if i still depends on j after conditioning on it.
                colliders = []
                for candidate in common:
                    _, W_ij_c_mag = compute_partial_Wiener_coeffs_avg(
                        data_FFT, i, j, candidate, freq_choice=range(nFFT))
                    if W_ij_c_mag[0] > thresh:
                        colliders.append(candidate)
            for k in colliders:
                collider_graph[i, k] = False
                collider_graph[j, k] = False
                collider_set.add(k)
    return collider_graph, collider_set


def _orient_away_from_colliders(collider_graph, collider_set):
    """Orient each still-undirected edge away from any collider endpoint.

    NOTE(review): an undirected edge whose endpoints are both colliders loses
    both directions, i.e. it is removed. Confirm this is intended.
    """
    nNodes = collider_graph.shape[0]
    final_graph = np.copy(collider_graph)
    for i in range(nNodes):
        for j in range(i):
            if collider_graph[i, j] and collider_graph[j, i]:
                if i in collider_set:
                    final_graph[i, j] = 0  # leaves i -> j
                if j in collider_set:
                    final_graph[j, i] = 0  # leaves j -> i
    return final_graph



def data_processing(original_data_raw,nSamples,n_FFT,nNodes):
    """FFT ``nSamples`` consecutive, non-overlapping windows of ``n_FFT`` rows.

    Parameters
    ----------
    original_data_raw : ndarray, shape (T, nNodes)
        Raw time series. Must have at least ``nSamples * n_FFT`` rows.
    nSamples : int
        Number of windows (trajectories) to transform.
    n_FFT : int
        Window length / FFT size.
    nNodes : int
        Number of columns (nodes).

    Returns
    -------
    complex ndarray, shape (nSamples, nNodes, n_FFT)
        ``result[k]`` is the FFT along time of rows ``k*n_FFT`` through
        ``(k+1)*n_FFT - 1``, laid out as (node, frequency).
    """
    spectra = np.zeros((nSamples, nNodes, n_FFT), dtype=complex)
    for window_idx in range(nSamples):
        start = window_idx * n_FFT
        window = original_data_raw[start:start + n_FFT, :]
        spectra[window_idx] = np.fft.fft(window, axis=0).T
    return spectra


def get_confusion_matrix(true_graph,estimated_graph):
    """Count edge-level agreement between a reference graph and an estimate.

    Only off-diagonal entries of the leading ``n x n`` block are compared,
    where ``n = true_graph.shape[0]``. An entry counts as an edge when it equals
    1 and as a non-edge when it equals 0 (``True``/``False`` compare the same
    way). Any other value is ignored.

    Returns
    -------
    tuple of int
        ``(TP, FP, TN, FN)``.
    """
    n = true_graph.shape[0]
    truth = true_graph[:n, :n]
    guess = estimated_graph[:n, :n]
    off_diagonal = ~np.eye(n, dtype=bool)

    truth_edge = (truth == 1) & off_diagonal
    truth_gap = (truth == 0) & off_diagonal
    guess_edge = guess == 1
    guess_gap = guess == 0

    TP = int(np.count_nonzero(truth_edge & guess_edge))
    FP = int(np.count_nonzero(truth_gap & guess_edge))
    TN = int(np.count_nonzero(truth_gap & guess_gap))
    FN = int(np.count_nonzero(truth_edge & guess_gap))
    return TP, FP, TN, FN

###########################################################################################################################
# Main: run the estimator on one river-runoff data set and score it against
# the known essential graph.
nNodes = 12
num_nodes = nNodes
nFFT = 4

# pathlib joins with the host OS separator, so the path resolves on Windows
# and POSIX alike. Backslash-separated literals only work on Windows.
data_path = (Path('.') / 'river-runoff' / 'river-runoff_N-12_T-4600'
             / 'river-runoff_N-12_T-4600_0001.txt')
data = np.loadtxt(data_path, delimiter=' ')
if data.ndim != 2 or data.shape[1] != nNodes:
    raise ValueError(
        f"expected a 2-D array with {nNodes} columns in {data_path}, got shape {data.shape}")

thresh_freq = np.array([0.108])
error_freq_PC = np.zeros(len(thresh_freq))  # not used below
TPR = np.zeros(len(thresh_freq))
FPR = np.zeros(len(thresh_freq))

# Reference graph. It is compared entry-by-entry (diagonal ignored) with the
# estimate returned by Run_Wiener_Phase_avg.
true_ess_graph = np.array([[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                           [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                           [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                           [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                           [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
                           [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
                           [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                           [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                           [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                           [0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0],
                           [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1],
                           [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])

ind_thresh = 0
# perf_counter is monotonic and high-resolution, which suits elapsed-time
# measurement better than wall-clock time.time().
start_time = time.perf_counter()
est_ess_graph = Run_Wiener_Phase_avg(data, nFFT, thresh=thresh_freq[ind_thresh])
end_time = time.perf_counter()
[TP, FP, TN, FN] = get_confusion_matrix(true_ess_graph, est_ess_graph)
TPR[ind_thresh] = TP / (TP + FN)
FPR[ind_thresh] = FP / (TN + FP)
elapsed_time = end_time - start_time
print("----------------------------------------______")
print('Threshold=', thresh_freq[ind_thresh])
print(f"TPR = {TPR[ind_thresh]}, FPR = {FPR[ind_thresh]}")
print(f"Elapsed algorithm run time = {elapsed_time}")

# TPR - FPR for each threshold (computed but not printed).
CS = TPR - FPR













