"""
This script runs the FFT-based WPC algorithm on the river-runoff dataset and reports the results (TPR, FPR and algorithm run time).
"""



import os
import time
from itertools import combinations

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np



# %% Graph plotting functions

def plot_topology(adjacency_matrix):
    """Draw the undirected skeleton encoded by ``adjacency_matrix`` (non-blocking).

    Nodes are labelled ``V1 .. Vn`` with ``n = adjacency_matrix.shape[0]``; every
    truthy entry ``adjacency_matrix[i, j]`` (i, j < n) adds the undirected edge
    ``V(i+1) -- V(j+1)``.

    NOTE(review): the layout below hard-codes coordinates for V1..V12 only; a
    graph with more nodes would need more entries here (networkx's drawing is
    expected to reject nodes without a position -- confirm).
    """
    n = adjacency_matrix.shape[0]
    names = ['V' + str(k) for k in range(1, n + 1)]

    skeleton = nx.Graph()
    skeleton.add_nodes_from(names)
    skeleton.add_edges_from(
        (names[row], names[col])
        for row in range(n)
        for col in range(n)
        if adjacency_matrix[row, col]
    )

    # Fixed layout for the 12-node river-runoff network (V1 .. V12, in order).
    # An older 6-node layout was:
    # {V1: (-1, 10), V2: (0, 0), V3: (4.5, 7), V4: (9, 8), V5: (12, 1), V6: (16, 3)}
    coordinates = [(-2, 0), (0, 2), (1, -0.5), (2, 1), (3, -1), (5, 0),
                   (5, 2.25), (5, 4.5), (7, -3), (9, 0.75), (7, -0.5), (4, -4)]
    pos = {'V' + str(k): xy for k, xy in enumerate(coordinates, start=1)}

    draw_options = {
        "font_size": 12,
        "node_size": 1000,
        "node_color": "white",
        "edgecolors": "black",
        "linewidths": 2,
        "width": 2
    }
    nx.draw_networkx(skeleton, pos, **draw_options)
    # Pad the axes so nodes drawn at the extreme coordinates are not clipped.
    plt.gca().margins(0.20)
    plt.axis("off")
    plt.show(block=False)
    # Let the GUI event loop render the non-blocking window.
    plt.pause(0.001)
    
    
    
    
    
    
    
def plot_directed_graph(adjacency_matrix):
    """Draw the directed graph encoded by ``adjacency_matrix`` (non-blocking).

    Nodes are labelled ``V1 .. Vn`` with ``n = adjacency_matrix.shape[0]``. A
    truthy entry ``adjacency_matrix[i, j]`` (i, j < n) is drawn as the arrow
    ``V(j+1) -> V(i+1)``: the column index is the tail, the row index the head.

    NOTE(review): the layout below hard-codes coordinates for V1..V12 only; a
    graph with more nodes would need more entries here (networkx's drawing is
    expected to reject nodes without a position -- confirm).
    """
    n = adjacency_matrix.shape[0]
    names = ['V' + str(k) for k in range(1, n + 1)]

    digraph = nx.DiGraph()
    digraph.add_nodes_from(names)
    digraph.add_edges_from(
        (names[col], names[row])
        for row in range(n)
        for col in range(n)
        if adjacency_matrix[row, col]
    )

    # Fixed layout for the 12-node river-runoff network (V1 .. V12, in order).
    # An older 6-node layout was:
    # {V1: (-1, 10), V2: (0, 0), V3: (5, 7), V4: (8, 8), V5: (12, 1), V6: (16, 3)}
    coordinates = [(-2, 0), (0, 2), (1, -1), (2, 1), (3, -1), (5, 0),
                   (5, 2.25), (5, 4.5), (7, -3), (9, 0.75), (7, -0.5), (4, -4)]
    pos = {'V' + str(k): xy for k, xy in enumerate(coordinates, start=1)}

    draw_options = {
        "font_size": 12,
        "node_size": 1000,
        "node_color": "white",
        "edgecolors": "black",
        "linewidths": 2,
        "width": 2,
    }
    nx.draw_networkx(digraph, pos, **draw_options)
    # Pad the axes so nodes drawn at the extreme coordinates are not clipped.
    plt.gca().margins(0.20)
    plt.axis("off")
    plt.show(block=False)
    # Let the GUI event loop render the non-blocking window.
    plt.pause(0.001)

# %% Frequency domain projection functions

def Wiener_Proj(X,Z):
    """Return the least-squares (Wiener) coefficients that project ``Z`` onto the columns of ``X``.

    The coefficients ``a`` solve the normal equations ``(X^H X) a = X^H Z``, so
    that ``X @ a`` is the orthogonal projection of ``Z`` onto the span of X's
    columns, i.e. ``Z ~= a_0 x_0 + a_1 x_1 + ... + a_{m-1} x_{m-1}`` with ``x_k``
    the k-th column of ``X``. ``^H`` is the conjugate transpose.

    Parameters
    ----------
    X : ndarray, shape (n,) or (n, m)
        Regressors (one per column). A 1-D ``X`` is treated as a single regressor.
    Z : ndarray, shape (n,) (or (n, k) for several targets in the 2-D case)
        Signal to be projected.

    Returns
    -------
    A scalar coefficient when ``X`` is 1-D; otherwise an array of shape (m,)
    (or (m, k)).

    Raises
    ------
    numpy.linalg.LinAlgError
        If ``X`` is 2-D and ``X^H X`` is singular.
    """
    X_star = X.conj().T
    XX = np.matmul(X_star, X)
    if np.isscalar(XX):
        # 1-D X: X^H X collapses to a scalar, so the normal equation is a division.
        return np.matmul(np.divide(X_star, XX), Z)
    # Solve the normal equations directly: reuses XX instead of recomputing it
    # and avoids forming an explicit inverse (better conditioned, same result).
    return np.linalg.solve(XX, np.matmul(X_star, Z))






def compute_Wiener_coeffs(nNodes,data):
    """Wiener coefficients of every node on all the other nodes, per frequency bin.

    For every target node ``p`` and every bin ``f`` (axis 2 of ``data``), the
    column ``data[:, p, f]`` is projected with ``Wiener_Proj`` onto the other
    ``nNodes - 1`` columns of ``data[:, :, f]``. The coefficient of ``p`` on
    itself is stored as 0.

    Parameters
    ----------
    nNodes : int
        Number of nodes (size of axis 1 of ``data``).
    data : ndarray
        Indexed ``[row, node, frequency]``, e.g. the output of ``compute_fft``.

    Returns
    -------
    W : complex ndarray, shape (nNodes, data.shape[2], nNodes)
        ``W[p, f, k]`` is the coefficient of node ``k`` when projecting node ``p``
        at bin ``f``.
    W_mag : float ndarray, shape (nNodes, nNodes)
        Maximum of ``|W[p, :, k]|`` over the bins (H-infinity norm over frequency).
    """
    n_freq = data.shape[2]
    W = np.zeros([nNodes, n_freq, nNodes], dtype='complex')
    W_mag = np.zeros([nNodes, nNodes])
    for target in range(nNodes):
        per_bin = np.zeros([n_freq, nNodes], dtype='complex')
        for f in range(n_freq):
            bin_data = data[:, :, f]
            others = np.delete(bin_data, target, 1)
            coeffs = Wiener_Proj(others, bin_data[:, target])
            # Re-insert a zero so column k keeps meaning "node k".
            per_bin[f, :] = np.insert(coeffs, target, 0)
        W[target] = per_bin
        W_mag[target] = np.abs(per_bin).max(axis=0)
    return W, W_mag




def compute_partial_Wiener_coeffs_hinf(data,i,j,z=(),freq_choice=None):
    """Partial Wiener coefficients of node ``i`` on ``[j, *z]``, with their peak magnitude.

    At each selected frequency bin ``f`` the column ``data[:, i, f]`` is projected
    (via ``Wiener_Proj``) onto the columns ``[j, *z]`` of ``data[:, :, f]``.

    Parameters
    ----------
    data : ndarray
        Indexed ``[row, node, frequency]``, e.g. the output of ``compute_fft``.
    i : int
        Target node.
    j : int
        Node whose coefficient is of primary interest; always column 0 of the result.
    z : iterable of int, optional
        Conditioning nodes. Duplicates (and repeats of ``j``) are dropped, keeping
        first occurrences in order.
    freq_choice : iterable of int, optional
        Frequency bins to use. Defaults to all bins, ``range(data.shape[2])``.

    Returns
    -------
    W : complex ndarray, shape (data.shape[2], n_regressors)
        Coefficients per bin. Rows of bins not in ``freq_choice`` stay 0.
    W_mag : float ndarray, shape (n_regressors,)
        Maximum of ``|W|`` over the bins (H-infinity norm over frequency).
    """
    # Ordered, de-duplicated regressor list: j first, then the conditioning set.
    regressors = list(dict.fromkeys(np.append(j, np.asarray(z, dtype=np.int32)).tolist()))
    n_FFT = data.shape[2]
    if freq_choice is None:
        freq_choice = range(n_FFT)
    W = np.zeros([n_FFT, len(regressors)], dtype='complex')
    for freq_index in freq_choice:
        fft_data = data[:, :, freq_index]
        W[freq_index, :] = Wiener_Proj(fft_data[:, regressors], fft_data[:, i])
    # Unselected bins are left at 0, which cannot raise the maximum of |W|.
    W_mag = np.amax(np.abs(W), axis=0)
    return W, W_mag

def compute_partial_Wiener_coeffs_avg(data,i,j,z=(),freq_choice=None):
    """Partial Wiener coefficients of node ``i`` on ``[j, *z]``, with their frequency-averaged magnitude.

    At each selected frequency bin ``f`` the column ``data[:, i, f]`` is projected
    (via ``Wiener_Proj``) onto the columns ``[j, *z]`` of ``data[:, :, f]``.

    Parameters
    ----------
    data : ndarray
        Indexed ``[row, node, frequency]``, e.g. the output of ``compute_fft``.
    i : int
        Target node.
    j : int
        Node whose coefficient is of primary interest; always column 0 of the result.
    z : iterable of int, optional
        Conditioning nodes. Duplicates (and repeats of ``j``) are dropped, keeping
        first occurrences in order.
    freq_choice : iterable of int, optional
        Frequency bins to average over. Defaults to all bins, ``range(data.shape[2])``.

    Returns
    -------
    W : complex ndarray, shape (data.shape[2], n_regressors)
        Coefficients per bin. Rows of bins not in ``freq_choice`` stay 0.
    W_mag_avg : float ndarray, shape (n_regressors,)
        Mean of ``|W[f, :]|`` over the bins ``f`` in ``freq_choice``.

    Raises
    ------
    ValueError
        If ``freq_choice`` is explicitly empty. The mean would otherwise be 0/0,
        i.e. NaN.
    """
    # Ordered, de-duplicated regressor list: j first, then the conditioning set.
    regressors = list(dict.fromkeys(np.append(j, np.asarray(z, dtype=np.int32)).tolist()))
    n_FFT = data.shape[2]
    freq_choice = list(range(n_FFT)) if freq_choice is None else list(freq_choice)
    if not freq_choice:
        raise ValueError("freq_choice must contain at least one frequency index")
    W = np.zeros([n_FFT, len(regressors)], dtype='complex')
    W_mag_avg = np.zeros([len(regressors)], dtype=float)
    for freq_index in freq_choice:
        fft_data = data[:, :, freq_index]
        W[freq_index, :] = Wiener_Proj(fft_data[:, regressors], fft_data[:, i])
        W_mag_avg += np.abs(W[freq_index, :])
    W_mag_avg /= len(freq_choice)
    return W, W_mag_avg

def compute_fft(data,nfft):
    """Split a multivariate time series into consecutive segments and FFT each one.

    Parameters
    ----------
    data : ndarray, shape (nSamples, nNodes)
        Rows are time samples, columns are nodes.
    nfft : int
        Number of samples per segment, which is also the number of FFT points.

    Returns
    -------
    y : complex ndarray, shape (nSamples // nfft, nNodes, nfft)
        ``y[t, k, :]`` is the DFT of column ``k`` over rows
        ``t*nfft .. (t+1)*nfft - 1``. Trailing samples that do not fill a whole
        segment are discarded.
    """
    nSamples, nNodes = data.shape
    # Floor division already excludes the incomplete tail segment, so every one
    # of these nTrajectories segments is complete and must be transformed.
    nTrajectories = nSamples // nfft
    y = np.zeros((nTrajectories, nNodes, nfft), dtype=complex)
    for ind in range(nTrajectories):
        x = data[ind * nfft:(ind + 1) * nfft, :]
        # FFT along time (axis 0), then store as (node, frequency).
        y[ind, :, :] = np.transpose(np.fft.fft(x, axis=0))
    return y




def Run_PC_Wiener_test_FFT_avg(data,nFFT,threshd):
    """Search a separating set for every node pair using frequency-domain Wiener tests.

    The time series is segmented and transformed with ``compute_fft``. For each
    pair ``(ind1, ind2)`` with ``ind1 > ind2``, conditioning sets drawn from the
    remaining nodes are tried in increasing size (``itertools.combinations``
    order, starting with the empty set). The pair is declared separated by the
    first set ``c`` for which the frequency-averaged magnitude of ``ind2``'s
    coefficient, when projecting ``ind1`` onto ``[ind2, *c]``, is below ``threshd``.

    Parameters
    ----------
    data : ndarray, shape (nSamples, nNodes)
        Time-series data, one column per node.
    nFFT : int
        Segment length / number of FFT points. All nFFT bins are used.
    threshd : float
        Threshold on the averaged coefficient magnitude.

    Returns
    -------
    C : dict
        Keys are ``(ind1, ind2)`` with ``ind1 > ind2`` (0-based). Each value is the
        first separating set found (a tuple of 0-based node indices, possibly
        empty), or ``None`` when no subset separated the pair.
    """
    nNodes = data.shape[1]
    data_FFT = compute_fft(data, nfft=nFFT)
    freq_choice = range(nFFT)
    C = {}
    for ind1 in range(nNodes):
        for ind2 in range(ind1):
            C[ind1, ind2] = _find_separating_set(data_FFT, ind1, ind2, nNodes, threshd, freq_choice)
    return C


def _find_separating_set(data_FFT, ind1, ind2, nNodes, threshd, freq_choice):
    """Return the first conditioning set (smallest first) that separates ind1 and ind2, else None."""
    others = [k for k in range(nNodes) if k != ind1 and k != ind2]
    # Sizes 0 .. len(others): every subset of the other nodes is eligible.
    for size in range(nNodes - 1):
        for c in combinations(others, size):
            _, W_mag = compute_partial_Wiener_coeffs_avg(data_FFT, ind1, ind2, c, freq_choice)
            # Column 0 is ind2's coefficient; a small value means ind1 carries no
            # information about ind2 once c is accounted for.
            if W_mag[0] < threshd:
                return c
    return None


################################################################################################################################
def data_processing(original_data_raw,nSamples,n_FFT,nNodes):
    """FFT ``nSamples`` consecutive, non-overlapping blocks of ``n_FFT`` rows each.

    Parameters
    ----------
    original_data_raw : ndarray, shape (>= nSamples * n_FFT, nNodes)
        Time series, one column per node.
    nSamples : int
        Number of blocks (trajectories) to transform.
    n_FFT : int
        Rows per block and number of FFT points.
    nNodes : int
        Number of columns (nodes).

    Returns
    -------
    complex ndarray, shape (nSamples, nNodes, n_FFT)
        Entry ``[b, k, :]`` is the DFT of column ``k`` over block ``b``.
    """
    spectra = np.zeros((nSamples, nNodes, n_FFT), dtype=complex)
    for block in range(nSamples):
        start = block * n_FFT
        segment = original_data_raw[start:start + n_FFT:1, :]
        spectra[block, :, :] = np.fft.fft(segment, axis=0).T
    return spectra


def normalize_data(data):
    """Z-score every column: subtract the column mean and divide by the column std.

    The standard deviation is numpy's default population estimate (ddof=0). A
    constant column has zero std and yields nan/inf entries.
    """
    spread = np.std(data, axis=0)
    centered = data - np.mean(data, axis=0)
    return centered / spread

# %% Essential graph estimation function

def estimate_essential_graph(Top_est,C):
    """Mark colliders in the estimated skeleton and orient the edges into them.

    Node ``ind1`` is a collider for the neighbour pair ``(ind2, ind3)`` (with
    ``ind2 < ind3``) when both are adjacent to it in ``Top_est``, the pair has a
    recorded separating set, and ``ind1`` is not in that set. For each such
    collider, ``Est_adj[ind2, ind1]`` and ``Est_adj[ind3, ind1]`` are set to 0 and
    a message is printed.

    Parameters
    ----------
    Top_est : square ndarray
        Symmetric skeleton (1 = adjacent). It is not modified.
    C : dict
        Keys ``(a, b)`` with ``a > b``. Each value is a separating set (a container
        of node indices), or ``None`` when the pair was never separated.

    Returns
    -------
    Est_adj : ndarray
        Copy of ``Top_est`` with the collider orientations applied.
    colliders : list of int
        Distinct collider nodes (unordered).
    """
    n = Top_est.shape[0]
    colliders = []
    Est_adj = np.copy(Top_est)
    for ind1 in range(n):
        for ind2 in range(n):
            if ind2 == ind1 or Top_est[ind1, ind2] != 1:
                continue
            for ind3 in range(ind2 + 1, n):
                if ind3 == ind1 or Top_est[ind1, ind3] != 1:
                    continue
                # ind3 > ind2, so this key follows C's (larger, smaller) convention.
                # None means the pair is adjacent (never separated): no v-structure.
                sep_set = C[(ind3, ind2)]
                if sep_set is not None and ind1 not in sep_set:
                    print(ind1, " is a collider")
                    colliders.append(ind1)
                    Est_adj[ind2, ind1] = 0
                    Est_adj[ind3, ind1] = 0
    colliders = list(set(colliders))
    return Est_adj, colliders

def estimate_essential_graph_new(Top_est,C):
    """Orient the estimated skeleton using detected v-structures (colliders).

    An undirected edge ``a -- b`` has both ``[a, b]`` and ``[b, a]`` equal to 1.
    Orienting an edge clears one of the two entries.

    1. Collider identification: for every non-adjacent pair ``i > j`` with a
       separating set, each common neighbour ``k`` not in that set is a collider.
       ``[i, k]`` and ``[j, k]`` are cleared.
    2. Each edge that is still undirected and touches a collider gets the entry
       in the collider's row cleared.

    Parameters
    ----------
    Top_est : square ndarray
        Symmetric skeleton (1 = adjacent). It is not modified.
    C : dict
        Keys ``(a, b)`` with ``a > b``. Each value is a separating set (a container
        of node indices), or ``None`` when the pair was never separated.

    Returns
    -------
    final_graph : ndarray
        Graph after both steps.
    collider_graph : ndarray
        Graph after step 1 only.
    """
    n = Top_est.shape[0]

    # % Collider Identification
    collider_graph = np.copy(Top_est)
    collider_set = set()
    for i in range(n):
        for j in range(i):
            if Top_est[i, j] == 1:
                continue  # adjacent endpoints cannot form a v-structure
            for k in range(n):
                if k == i or k == j or Top_est[i, k] != 1 or Top_est[j, k] != 1:
                    continue
                sep_set = C[(i, j)]
                if sep_set is not None and k not in sep_set:
                    collider_graph[i, k] = 0
                    collider_graph[j, k] = 0
                    collider_set.add(k)

    # % Orienting other edges if possible
    final_graph = np.copy(collider_graph)
    for i in range(n):
        for j in range(i):
            if collider_graph[i, j] == 1 and collider_graph[j, i] == 1:
                # NOTE(review): if both endpoints are colliders, both entries
                # are cleared and the edge disappears from final_graph.
                if i in collider_set:
                    final_graph[i, j] = 0
                if j in collider_set:
                    final_graph[j, i] = 0

    return final_graph, collider_graph


def get_confusion_matrix(true_graph,estimated_graph):
    """Count edge-level TP/FP/TN/FN between two adjacency matrices.

    Only off-diagonal entries of the top-left ``n x n`` block are compared, with
    ``n = true_graph.shape[0]``. An entry contributes only when both the true
    and estimated values equal exactly 0 or 1; any other value is ignored.

    Returns
    -------
    tuple of int
        ``(TP, FP, TN, FN)``.
    """
    n = true_graph.shape[0]
    tp = fp = tn = fn = 0
    off_diagonal = ((r, c) for r in range(n) for c in range(n) if r != c)
    for r, c in off_diagonal:
        actual = true_graph[r, c]
        predicted = estimated_graph[r, c]
        if actual == 1:
            if predicted == 1:
                tp += 1
            elif predicted == 0:
                fn += 1
        elif actual == 0:
            if predicted == 1:
                fp += 1
            elif predicted == 0:
                tn += 1
    return tp, fp, tn, fn

############################################################################################################################################################  

############################################################################################################################################################
# %% Initialization, data loading and data processing

nNodes = 12
num_nodes = nNodes
n_FFT = 4

# Build the path portably: a Windows-style backslash string only resolves on
# Windows, whereas os.path.join uses the platform's separator.
data_path = os.path.join(".", "river-runoff", "river-runoff_N-12_T-4600",
                         "river-runoff_N-12_T-4600_0001.txt")
data = np.loadtxt(data_path, delimiter=' ')


# %% PC test

thresh_freq = np.array([0.404])
error_freq_PC = np.zeros(len(thresh_freq))
TPR = np.zeros(len(thresh_freq))
FPR = np.zeros(len(thresh_freq))
TPR_paper = np.zeros(len(thresh_freq))
FPR_paper = np.zeros(len(thresh_freq))
true_ess_graph = np.array([[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                           [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                           [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                           [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                           [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
                           [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
                           [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                           [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                           [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                           [0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0],
                           [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1],
                           [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])

# Evaluate every threshold in thresh_freq (a single one by default).
for ind_thresh, threshold in enumerate(thresh_freq):
    start_time = time.time()
    C_freq = Run_PC_Wiener_test_FFT_avg(data, n_FFT, threshold)
    # A pair whose separating set is None was never separated, so it stays
    # adjacent in the skeleton (C_freq keys have first index > second).
    Top_est_half = np.zeros([num_nodes, num_nodes])
    for c in C_freq:
        if C_freq[c] is None:
            Top_est_half[c] = 1
    Top_est = Top_est_half + Top_est_half.T
    # The second return value is the collider-only graph, not a list of colliders.
    est_essential_graph_freq, est_colliders_freq = estimate_essential_graph_new(Top_est, C_freq)
    end_time = time.time()
    elapsed_time = end_time - start_time

    error_freq_PC[ind_thresh] = np.sum(est_essential_graph_freq != true_ess_graph) / np.sum(true_ess_graph)

    [TP, FP, TN, FN] = get_confusion_matrix(true_ess_graph, est_essential_graph_freq)
    # Guard the rates against empty denominators instead of raising ZeroDivisionError.
    TPR[ind_thresh] = TP / (TP + FN) if (TP + FN) else np.nan
    FPR[ind_thresh] = FP / (TN + FP) if (TN + FP) else np.nan

    print("----------------------------------------______")
    print('Threshold=', thresh_freq[ind_thresh])
    print(f"TPR = {TPR[ind_thresh]}, FPR = {FPR[ind_thresh]}")
    print(f"Elapsed algorithm run time = {elapsed_time}")

CS = TPR - FPR



