"""
This script implements FFT WPC algorithm on a dataset where strict temporal caulality is violated
this is used to compare it with Granger causality.
"""


import numpy as np
from itertools import combinations
import time



# %% Frequency domain projection functions

def Wiener_Proj(X,Z):
    """Return the least-squares coefficients of the projection of Z onto X.

    Solves the normal equations (X^H X) a = X^H Z, i.e. finds ``a`` such that
    Z ~= a0*x0 + a1*x1 + ... + a(m-1)*x(m-1), where x_k are the columns of X
    and m is the number of columns.

    Parameters
    ----------
    X : ndarray
        Regressors, one per column (2-D), or a single regressor given as a
        1-D vector.
    Z : ndarray
        Quantity being projected; its leading dimension must match X's.

    Returns
    -------
    ndarray or scalar
        One coefficient per column of X (a scalar when X and Z are both 1-D).

    Raises
    ------
    numpy.linalg.LinAlgError
        If X^H X is singular.
    """
    X_star = X.conj().T
    gram = np.matmul(X_star, X)
    if np.isscalar(gram):
        # 1-D X: the Gram "matrix" collapses to the scalar <X, X>.
        return np.matmul(np.divide(X_star, gram), Z)
    # Solve the normal equations instead of forming an explicit inverse:
    # the Gram matrix is computed once and the solve is better conditioned
    # than inv(...) followed by two products.
    return np.linalg.solve(gram, np.matmul(X_star, Z))






def compute_Wiener_coeffs(nNodes,data):
    """Project every node onto all the other nodes, frequency by frequency.

    ``data`` is indexed as ``data[:, :, freq]``; within each such slice the
    columns are the nodes. For each target column the remaining columns are
    used as regressors via ``Wiener_Proj``.

    Returns
    -------
    W : complex ndarray, shape (nNodes, data.shape[2], nNodes)
        ``W[p, f, :]`` holds the coefficients for target ``p`` at frequency
        index ``f``, with a 0 inserted at position ``p`` (no self term).
    W_mag : ndarray, shape (nNodes, nNodes)
        Row ``p`` is the maximum of ``|W[p]|`` over the frequency axis
        (the h-infinity norm over frequencies).
    """
    n_freq = data.shape[2]
    W = np.zeros((nNodes, n_freq, nNodes), dtype=complex)
    W_mag = np.zeros((nNodes, nNodes))
    for target in range(nNodes):
        per_freq = W[target]  # view: writes land directly in W
        for k in range(n_freq):
            snapshot = data[:, :, k]
            regressors = np.delete(snapshot, target, 1)
            coeffs = Wiener_Proj(regressors, snapshot[:, target])
            # Re-insert a zero where the target's own column was removed.
            per_freq[k, :] = np.insert(coeffs, target, 0)
        W_mag[target, :] = np.amax(abs(per_freq), axis=0)
    return W, W_mag




def compute_partial_Wiener_coeffs_hinf(data,i,j,z=None,freq_choice=None):
    """Project node ``i`` onto node ``j`` plus the conditioning nodes ``z``.

    Parameters
    ----------
    data : ndarray
        Indexed as ``data[:, :, freq]``; columns of each slice are nodes.
    i : int
        Column that is projected.
    j : int
        First regressor column; its coefficient is always at position 0.
    z : sequence of int, optional
        Further regressor columns. Duplicates (including a repeat of ``j``)
        are dropped, keeping first-occurrence order. Defaults to none.
    freq_choice : iterable of int, optional
        Frequency indices to evaluate. Defaults to none, in which case every
        coefficient stays zero.

    Returns
    -------
    W : complex ndarray, shape (data.shape[2], n_regressors)
        Coefficients per frequency index; rows not listed in ``freq_choice``
        remain zero.
    W_mag : ndarray, shape (n_regressors,)
        Maximum of ``|W|`` over the frequency axis (h-infinity norm).
    """
    # None sentinels instead of mutable [] defaults.
    z = [] if z is None else z
    freq_choice = [] if freq_choice is None else freq_choice

    # Regressor columns: j first, then z, de-duplicated in original order.
    c = np.append(j, np.int32(z))
    first_seen = np.unique(c, return_index=True)[1]
    c = [c[index] for index in sorted(first_seen)]

    n_FFT = data.shape[2]
    W = np.zeros([n_FFT, len(c)], dtype=complex)
    for freq_index in freq_choice:
        fft_data = data[:, :, freq_index]
        W[freq_index, :] = Wiener_Proj(fft_data[:, c], fft_data[:, i])
    W_mag = np.amax(abs(W), axis=0)
    return W, W_mag

def compute_partial_Wiener_coeffs_avg(data,i,j,z=None,freq_choice=None):
    """Project node ``i`` onto ``j`` and ``z``; average |coeff| over frequencies.

    Parameters
    ----------
    data : ndarray
        Indexed as ``data[:, :, freq]``; columns of each slice are nodes.
    i : int
        Column that is projected.
    j : int
        First regressor column; its coefficient is always at position 0.
    z : sequence of int, optional
        Further regressor columns. Duplicates (including a repeat of ``j``)
        are dropped, keeping first-occurrence order. Defaults to none.
    freq_choice : sized iterable of int, optional
        Frequency indices to evaluate and average over. Defaults to none.

    Returns
    -------
    W : complex ndarray, shape (data.shape[2], n_regressors)
        Coefficients per frequency index; rows not listed in ``freq_choice``
        remain zero.
    W_mag_avg : float ndarray, shape (n_regressors,)
        Sum of ``|W[f]|`` over ``freq_choice`` divided by
        ``len(freq_choice)``. An empty ``freq_choice`` therefore divides by
        zero and yields NaN entries.
    """
    # None sentinels instead of mutable [] defaults.
    z = [] if z is None else z
    freq_choice = [] if freq_choice is None else freq_choice

    # Regressor columns: j first, then z, de-duplicated in original order.
    c = np.append(j, np.int32(z))
    first_seen = np.unique(c, return_index=True)[1]
    c = [c[index] for index in sorted(first_seen)]

    n_FFT = data.shape[2]
    W = np.zeros([n_FFT, len(c)], dtype=complex)
    W_mag_avg = np.zeros([len(c)], dtype=float)
    for freq_index in freq_choice:
        fft_data = data[:, :, freq_index]
        W[freq_index, :] = Wiener_Proj(fft_data[:, c], fft_data[:, i])
        W_mag_avg = W_mag_avg + np.abs(W[freq_index, :])
    W_mag_avg = np.divide(W_mag_avg, len(freq_choice))
    return W, W_mag_avg

# Compute the FFT of the data, one non-overlapping segment at a time
# Inputs:  (1) data: time-series data, shape (nSamples, nNodes)
#          (2) nfft: number of FFT points (length of each segment)
# Output:  (1) y: the FFT of every full segment,
#              shape (nTrajectories, nNodes, nfft)
def compute_fft(data,nfft):
    nSamples,nNodes=data.shape
    # Integer division already drops the trailing samples that do not fill
    # a whole segment.
    nTrajectories = nSamples // nfft
    y = np.zeros((nTrajectories, nNodes, nfft), dtype=complex)
    # Transform every full segment. Stopping one short (as before) left the
    # last row of y as all zeros even though its samples were available.
    for ind in range(nTrajectories):
        segment = data[ind * nfft:(ind + 1) * nfft, :]
        spectrum = np.fft.fft(segment, axis=0)  # (nfft, nNodes)
        y[ind, :, :] = np.transpose(spectrum)
    return y




def Run_PC_Wiener_test_FFT_avg(data,nFFT,threshd,freq_choice):
    """Search a separating set for every node pair using averaged Wiener coefficients.

    ``data`` is a 2-D array whose columns are the nodes; it is segmented and
    transformed by ``compute_fft(data, nfft=nFFT)``. For each pair
    ``(ind1, ind2)`` with ``ind1 > ind2``, conditioning sets drawn from the
    remaining nodes are tried in order of increasing size (0 up to
    nNodes-2). The first set for which the averaged coefficient magnitude of
    ``ind2`` (entry 0 returned by ``compute_partial_Wiener_coeffs_avg``) is
    below ``threshd`` is recorded.

    Returns
    -------
    dict
        Maps ``(ind1, ind2)`` to the separating tuple that was found
        (possibly the empty tuple), or to ``None`` when no tested set brought
        the coefficient below the threshold.
    """
    nNodes = data.shape[1]
    data_FFT = compute_fft(data, nfft=nFFT)
    C = {}
    for ind1 in range(nNodes):
        for ind2 in range(ind1):
            # Candidate conditioning nodes: everything except the pair.
            others = [node for node in range(nNodes) if node not in (ind1, ind2)]
            separator = None
            for set_size in range(nNodes - 1):
                for candidate in combinations(others, set_size):
                    _, magnitudes = compute_partial_Wiener_coeffs_avg(
                        data_FFT, ind1, ind2, candidate, freq_choice)
                    if magnitudes[0] < threshd:
                        separator = candidate
                        break
                # An empty tuple is a valid separator, so compare to None.
                if separator is not None:
                    break
            C[ind1, ind2] = separator
    return C



################################################################################################################################
def data_processing(original_data_raw,nSamples,n_FFT,nNodes):
    """FFT consecutive, non-overlapping row blocks of the raw data.

    Block ``t`` consists of rows ``t*n_FFT`` up to ``(t+1)*n_FFT`` of
    ``original_data_raw``; it is transformed along the row axis and stored
    transposed.

    Here ``nSamples`` is the number of blocks to process, not the number of
    rows.

    Returns
    -------
    complex ndarray, shape (nSamples, nNodes, n_FFT)
    """
    spectra = np.zeros((nSamples, nNodes, n_FFT), dtype=complex)
    for block in range(nSamples):
        start = block * n_FFT
        segment = original_data_raw[start:start + n_FFT, :]
        spectra[block, :, :] = np.fft.fft(segment, axis=0).T
    return spectra


def normalize_data(data):
    """Standardize each column: subtract its mean, divide by its standard deviation.

    Statistics are taken along axis 0 with ``np.mean`` / ``np.std``. A column
    with zero standard deviation is divided by zero as-is (no guarding).
    """
    centered = data - np.mean(data, axis=0)
    return centered / np.std(data, axis=0)

# %% Essential graph estimation function

def estimate_essential_graph(Top_est,C):
    """Mark colliders in an estimated skeleton using the separating sets.

    For every node ``ind1`` and every pair of its skeleton neighbours
    ``ind2 < ind3``, the separating set stored under ``C[(ind3, ind2)]`` is
    consulted. If one exists (not ``None``) and does not contain ``ind1``,
    ``ind1`` is treated as a collider and the entries ``[ind2, ind1]`` and
    ``[ind3, ind1]`` of the returned matrix are cleared.

    Parameters
    ----------
    Top_est : ndarray, square
        Skeleton adjacency matrix; an entry equal to 1 marks an edge.
    C : dict
        Separating sets keyed ``(larger_index, smaller_index)``; ``None``
        means no separating set was found for that pair.

    Returns
    -------
    Est_adj : ndarray
        Copy of ``Top_est`` with the collider entries cleared.
    colliders : list
        The distinct collider node indices.
    """
    # Size comes from the matrix itself rather than a module-level global.
    n_nodes = Top_est.shape[0]
    colliders = list()
    Est_adj = np.copy(Top_est)
    for ind1 in range(n_nodes):
        for ind2 in range(n_nodes):
            if ind2 == ind1 or Top_est[ind1, ind2] != 1:
                continue
            for ind3 in range(ind2 + 1, n_nodes):
                if ind3 == ind1 or Top_est[ind1, ind3] != 1:
                    continue
                sep_set = C[(ind3, ind2)]
                # Report the set that is actually tested. The previous lookup
                # C[ind1, ind2] raised KeyError whenever ind1 < ind2, because
                # C is only keyed with the larger index first.
                print(ind1, ind2, ind3, sep_set)
                if sep_set is not None and ind1 not in sep_set:
                    print(ind1, " is a collider")
                    colliders.append(ind1)
                    Est_adj[ind2, ind1] = 0
                    Est_adj[ind3, ind1] = 0
    colliders = list(set(colliders))
    return Est_adj, colliders

def estimate_essential_graph_new(Top_est,C):
    """Orient a skeleton: first colliders, then edges touching a collider.

    Step 1 (colliders): for every non-adjacent pair ``i > j`` with a common
    neighbour ``k``, if a separating set ``C[(i, j)]`` exists and does not
    contain ``k``, then ``k`` is recorded as a collider and the entries
    ``[i, k]`` and ``[j, k]`` are cleared.

    Step 2: for every edge that is still present in both directions after
    step 1, the entry in the row of an endpoint that is a collider is
    cleared.

    Parameters
    ----------
    Top_est : ndarray, square
        Skeleton adjacency matrix; an entry equal to 1 marks an edge.
    C : dict
        Separating sets keyed ``(i, j)`` with ``i > j``; ``None`` means no
        separating set was found for that pair.

    Returns
    -------
    final_graph : ndarray
        Result after both steps.
    collider_graph : ndarray
        Result after step 1 only.
    """
    # Size comes from the matrix itself rather than a module-level global.
    n_nodes = Top_est.shape[0]

    # Step 1: collider identification.
    collider_graph = np.copy(Top_est)
    collider_set = set()
    for i in range(n_nodes):
        for j in range(i):
            if Top_est[i, j] == 1:
                continue  # adjacent pairs have no separating set to consult
            for k in range(n_nodes):
                if k == i or k == j:
                    continue
                if Top_est[i, k] != 1 or Top_est[j, k] != 1:
                    continue
                # Looked up only for genuine candidate triples, so pairs
                # missing from C are never touched unnecessarily.
                sep_set = C[(i, j)]
                if sep_set is not None and k not in sep_set:
                    collider_graph[i, k] = 0
                    collider_graph[j, k] = 0
                    collider_set.add(k)

    # Step 2: orient remaining two-way edges that touch a collider.
    final_graph = np.copy(collider_graph)
    for i in range(n_nodes):
        for j in range(i):
            if collider_graph[i, j] == 1 and collider_graph[j, i] == 1:
                if i in collider_set:
                    final_graph[i, j] = 0
                if j in collider_set:
                    final_graph[j, i] = 0

    return final_graph, collider_graph


def get_confusion_matrix(true_graph,estimated_graph):
    """Count edge-level agreement between two adjacency matrices.

    Only off-diagonal entries within the first ``true_graph.shape[0]`` rows
    and columns are compared, and only entries equal to 0 or 1 are counted.

    Returns
    -------
    tuple
        ``(TP, FP, TN, FN)`` where "positive" means an entry equal to 1.
    """
    size = true_graph.shape[0]
    tp = fp = tn = fn = 0
    for row in range(size):
        for col in range(size):
            if row == col:
                continue
            actual = true_graph[row, col]
            predicted = estimated_graph[row, col]
            if actual == 1:
                if predicted == 1:
                    tp += 1
                elif predicted == 0:
                    fn += 1
            elif actual == 0:
                if predicted == 1:
                    fp += 1
                elif predicted == 0:
                    tn += 1
    return tp, fp, tn, fn


# Build the skeleton and the moral graph of a directed graph
# Input:  (1) Adj: adjacency matrix of the directed graph
# Output: (1) MG: adjacency matrix of the moral graph
#         (2) Skeleton: adjacency matrix of the skeleton (Adj + Adj.T)
def find_moral_graph_and_skeleton_from_adjacency_matrix(Adj):
    nNodes = Adj.shape[0]
    Skeleton = Adj + Adj.T
    MG = np.copy(Skeleton)
    for row in range(nNodes):
        # Columns flagged with 1 in the same row get linked to each other
        # (presumably Adj[child, parent] == 1, i.e. co-parents are married).
        flagged = [col for col in range(nNodes) if Adj[row, col] == 1]
        for pos, first in enumerate(flagged):
            for second in flagged[pos + 1:]:
                MG[first, second] = 1
                MG[second, first] = 1
    return MG, Skeleton

def find_true_Essential_graph(Skeleton, Adj):
    """Derive the essential graph from the true directed graph and its skeleton.

    Step 1 (colliders): a node ``ind1`` with ``Adj[ind1, ind2] == 1`` and
    ``Adj[ind1, ind3] == 1`` for two distinct other nodes that are not linked
    in ``Skeleton`` is a collider; the entries ``[ind2, ind1]`` and
    ``[ind3, ind1]`` are cleared from a copy of ``Skeleton``.

    Step 2: for every edge still present in both directions after step 1,
    the entry in the row of an endpoint that is a collider is cleared.

    Parameters
    ----------
    Skeleton : ndarray, square
        Skeleton adjacency matrix; an entry equal to 1 marks an edge.
    Adj : ndarray, square
        Directed adjacency matrix of the true graph.

    Returns
    -------
    ndarray
        Adjacency matrix after both steps.
    """
    # Size comes from the matrix itself rather than a module-level global,
    # matching the other helpers that use ``.shape[0]``.
    n_nodes = Adj.shape[0]

    # Step 1: collider identification (ind1 is the collider candidate,
    # ind2 and ind3 the two nodes pointing into it).
    collider_graph = np.copy(Skeleton)
    collider_nodes = set()
    for ind1 in range(n_nodes):
        for ind2 in range(n_nodes):
            if ind2 == ind1 or Adj[ind1, ind2] != 1:
                continue
            for ind3 in range(n_nodes):
                if ind3 == ind1 or ind3 == ind2:
                    continue
                if Adj[ind1, ind3] == 1 and Skeleton[ind2, ind3] != 1:
                    collider_graph[ind2, ind1] = 0
                    collider_graph[ind3, ind1] = 0
                    collider_nodes.add(ind1)

    # Step 2: orient remaining two-way edges that touch a collider.
    true_ess_graph = np.copy(collider_graph)
    for i in range(n_nodes):
        for j in range(i):
            if collider_graph[i, j] == 1 and collider_graph[j, i] == 1:
                if i in collider_nodes:
                    true_ess_graph[i, j] = 0
                if j in collider_nodes:
                    true_ess_graph[j, i] = 0

    return true_ess_graph

############################################################################################################################################################  

############################################################################################################################################################
# %% Initialization, data loading and data processing

nNodes = 6
num_nodes = nNodes  # second name for the same node count
n_FFT = 32  # FFT length used to segment the time series



data = np.loadtxt("dataSet_nonStrict_Causal.txt", delimiter=' ')

# %% PC test
# One entry per threshold; currently a single threshold is evaluated.
thresh_freq = np.array([0.0743])
error_freq_PC = np.zeros(len(thresh_freq))
TPR = np.zeros(len(thresh_freq))
FPR = np.zeros(len(thresh_freq))
computation_time = np.zeros(len(thresh_freq))
freq_choice = range(n_FFT)  # evaluate every frequency bin

# Ground truth: directed graph -> moral graph / skeleton -> essential graph.
Adj = np.load("Adj_mat_nonStrict_Causal.npy")
MG,Skeleton = find_moral_graph_and_skeleton_from_adjacency_matrix(Adj)
true_ess_graph = find_true_Essential_graph(Skeleton, Adj)


ind_thresh = 0
start_time = time.time()
C_freq = Run_PC_Wiener_test_FFT_avg(data[:15000,:], n_FFT, thresh_freq[ind_thresh], freq_choice)
# A pair with no separating set (None) is kept as an edge of the skeleton.
Top_est_half = np.zeros([num_nodes,num_nodes])
for c in C_freq:
    if C_freq[c] is None: Top_est_half[c]=1
Top_est = (Top_est_half + Top_est_half.T)
# NOTE: the second return value is the graph after collider orientation only.
est_essential_graph_freq,est_colliders_freq = estimate_essential_graph_new(Top_est,C_freq)
end_time = time.time()

# Fraction of mismatching entries relative to the number of true edges.
error_freq_PC[ind_thresh] = np.sum(est_essential_graph_freq!=true_ess_graph)/np.sum(true_ess_graph)

TP, FP, TN, FN = get_confusion_matrix(true_ess_graph, est_essential_graph_freq)
TPR[ind_thresh] = TP/(TP+FN)
FPR[ind_thresh] = FP/(TN+FP)
computation_time[ind_thresh] = end_time - start_time


print("----------------------------------------______")
print('Threshold=', thresh_freq[ind_thresh])
print(f"TPR = {TPR[ind_thresh]}, FPR = {FPR[ind_thresh]}, CS = {TPR[ind_thresh]-FPR[ind_thresh]}")



CS = TPR - FPR



