import numpy as np
import pandas as pd
import warnings

from ComponentProcess import ComponentProcess
import Phase1_IdentifyingCausalRelations as Phase1
import Phase2_DiscoveringNewLatentComponentProcesses as Phase2
import SimulationTimeSeriesData as SD_timeSeries #simulate data
import Simulation_Hawkes_data_from_tick as SD_tick #simulate data
import Utils
import MakeGraph

def MainPOMHP(data, alpha=(0.05, 0.05), maxNrOfDirectCausesForPhase1=3, NrOfConsideredLagTerms=50):
    """
    Run the two-phase iterative algorithm and report the resulting causal graph.

    Every column of ``data`` becomes an observed ``ComponentProcess``. Phase 1
    (``Phase1.identifying_causal_relations``) and Phase 2
    (``Phase2.discovering_new_latent_component_processes``) are then alternated
    until the active process list is empty or an iteration leaves it unchanged.
    Afterwards the direct-effect relations are turned into an adjacency matrix
    (printed and appended to ``adj_matrix.txt``), the graph is drawn through
    ``MakeGraph.Make_mergedPyramid_graph`` and the relations are printed.

    Args:
        data: Table of observed series; accessed through ``data.columns`` and
            ``data[list_of_column_names]`` (presumably a pandas DataFrame --
            confirm against callers).
        alpha: Two-element sequence; ``alpha[0]`` is passed to Phase 1 and
            ``alpha[1]`` to Phase 2 as ``significance_level``. The default is
            an immutable tuple so it cannot be shared and mutated across calls.
        maxNrOfDirectCausesForPhase1: Forwarded to Phase 1; a constraint on the
            number of direct causes to reduce calculation complexity.
        NrOfConsideredLagTerms: Largest lag; lags ``0..NrOfConsideredLagTerms``
            are used to build the lagged observed data, and the value is also
            forwarded to both phases.

    Returns:
        None. Results are reported through stdout, ``adj_matrix.txt`` and the
        graph produced by ``MakeGraph``.
    """
    # Initialize the causal graph and sets
    indices = list(data.columns)
    allKnownCompProcesses = []  # every discovered process; used to represent the causal graph later
    A_G = []  # Active process list, initialized to the observed processes
    T_G = []  # Augmented process list, initialized to the observed processes

    for col in indices:
        observed_instance = ComponentProcess(
            process_name=col,
            isObserved=True,
            surrogate_A=None,
            surrogate_B=None,
            close_observed_effect_set=set(),
        )
        # Each observed process is handed to its own initializers.
        observed_instance.observed_initialize_surrogate_A(observed_instance)
        observed_instance.observed_initialize_surrogate_B(observed_instance)
        observed_instance.observed_initialize_close_observed_effect_set(observed_instance)
        observed_instance.observed_initialize_set_who_offer_surrogate_A(observed_instance)

        allKnownCompProcesses.append(observed_instance)
        A_G.append(observed_instance)
        T_G.append(observed_instance)

    # Since T_G always contains all observed processes, the lagged observed
    # data can be computed once and reused regardless of later changes in T_G.
    observed_T_G_names = [t.get_process_name() for t in T_G if t.get_isObserved()]
    lag_set = list(range(NrOfConsideredLagTerms + 1))
    observed_var_set_data_T_G = Utils.generate_lagged_time_series(data[observed_T_G_names], lag_set)

    # Algorithm starts
    currentNrOfLatentProcesses = 0
    numOfIter = 0
    # Snapshot of A_G from the previous iteration. It starts empty so the first
    # iteration always runs whenever there is at least one observed process.
    A_G_copy = []
    # Repeat until A_G is empty or an iteration leaves A_G unchanged
    while len(A_G) > 0 and A_G_copy != A_G:
        A_G_copy = A_G.copy()
        numOfIter += 1

        # phase1
        print("++++++++ Phase 1 start +++++++")
        A_G, T_G, allKnownCompProcesses = Phase1.identifying_causal_relations(
            allKnownCompProcesses,
            A_G,
            T_G,
            data,
            observed_var_set_data_T_G,
            maxNrOfDirectCausesForPhase1,
            NrOfConsideredLagTerms,
            significance_level=alpha[0],
        )

        # phase2
        print('')
        print("++++++++ Phase 2 start +++++++")
        A_G, T_G, allKnownCompProcesses, currentNrOfLatentProcesses = (
            Phase2.discovering_new_latent_component_processes(
                allKnownCompProcesses,
                A_G,
                T_G,
                data,
                observed_var_set_data_T_G,
                currentNrOfLatentProcesses,
                NrOfConsideredLagTerms,
                significance_level=alpha[1],
            )
        )

        # print summary
        print('')
        print("++++++++++++ Summary ++++++++++++++++")
        print("Active Processes Collection:", end='')
        for p in A_G:
            print(p.get_process_name(), end=' ')
        print("")
        print("++++++++++++ The",str(numOfIter)+"th","iteration end ++++++++++++++++")
        print("\n")

        # # need add further condition to distinguish Fig4a and Fig4b
        # if len(A_G) == 1 and A_G[0].get_isObserved() == False:
        #     break

    # Map each process name to the names of its direct effects. If a name
    # occurs more than once, only the first occurrence is recorded.
    direct_causal_relation = {}
    for process in allKnownCompProcesses:
        p_name = process.get_process_name()
        if p_name in direct_causal_relation:
            continue
        direct_causal_relation[p_name] = [
            effect.get_process_name() for effect in process.get_direct_effects()
        ]

    m = UpdateGraph(indices, direct_causal_relation)
    file_name = 'adj_matrix.txt'
    # Append mode: results of repeated runs accumulate in the same file.
    with open(file_name, 'a') as f:
        df_str = m.to_string(index=True)
        f.write(df_str + '\n\n')

    MakeGraph.Make_mergedPyramid_graph(allKnownCompProcesses)

    print("Causal relations:")
    for process in allKnownCompProcesses:
        direct_effects = [p.get_process_name() for p in process.get_direct_effects()]
        print("The process",process.get_process_name(),"has direct effects:",str(direct_effects))
    print("\n")

    return None








def UpdateGraph(obserd,LatentIndex):
    """
    Build, print and return the adjacency matrix of the causal graph.

    The node order is: the entries of ``obserd`` as given, then any key of
    ``LatentIndex`` not already listed, then any effect name that appears only
    inside a value list. Adding that last group keeps the matrix square and
    makes sure no edge is lost.

    Args:
        obserd: Iterable of observed node labels.
        LatentIndex: Mapping from a cause label to an iterable of the labels
            of its direct effects.

    Returns:
        A square ``pandas.DataFrame`` of ``int32`` whose rows and columns are
        the node labels. The entry at row ``effect`` and column ``cause`` is 1
        when ``effect`` is listed under ``LatentIndex[cause]``, otherwise 0.
    """
    variables = list(obserd)
    known = set(variables)

    # Causes that are not observed nodes come right after the observed ones.
    for cause in LatentIndex:
        if cause not in known:
            known.add(cause)
            variables.append(cause)

    # Effects mentioned only inside the value lists still need a row/column.
    for effects in LatentIndex.values():
        for effect in effects:
            if effect not in known:
                known.add(effect)
                variables.append(effect)

    n = len(variables)
    matrix = pd.DataFrame(
        np.zeros((n, n), dtype=np.int32), index=variables, columns=variables
    )

    for cause, effects in LatentIndex.items():
        for effect in effects:
            # Single-step .loc assignment. Chained ``matrix[cause][effect] = 1``
            # writes to a temporary under pandas Copy-on-Write and would leave
            # the matrix untouched.
            matrix.loc[effect, cause] = 1

    print(matrix)
    return matrix
    
    
    
  
#simulation: Randomly generate simulation data to validate our method
def main():
    """Run the simulated-data experiment ten times with fixed settings.

    Each repetition draws a fresh data set from
    ``SD_timeSeries.CaseFig1b(30000)`` and passes it to ``MainPOMHP``, which
    estimates the causal structure and plots the graph.
    """
    repetitions = 10
    # Constraints that reduce calculation complexity.
    max_direct_causes = 2
    lag_terms = 101  # 1000 #51

    for _ in range(repetitions):
        # Generate simulation data.
        simulated = SD_timeSeries.CaseFig1b(30000)

        # Significance levels: first entry for phase 1, second for phase 2.
        significance = [0.1, 0.1]

        # Estimate the causal structure from the observed data and plot it.
        MainPOMHP(simulated, significance, max_direct_causes, lag_terms)
        

# Run the simulation experiment only when this file is executed as a script.
if __name__ == "__main__":
    main()
    
    
    
    
    

    
    



