import numpy as np
import pandas as pd
import warnings
import Utils
from itertools import combinations
from ComponentProcess import ComponentProcess
import CCARankTester



def get_known_sibling_variables(N_1, P_G):
    """
    Get the observed sibling set of the direct effects of a latent process.

    The result contains:
    - the close observed effects of N_1, minus N_1's own surrogate;
    - the close observed effects of the process that offers the surrogate of
      each member of P_G.

    The surrogates of the P_G members are then removed. N_1's surrogate is
    removed *before* the P_G contributions are merged in, so it can still
    appear if one of those processes lists it as a close observed effect.

    Parameters:
    N_1 (process): The component process under consideration.
    P_G (list of instances of process): Candidate direct causes of N_1.

    Returns:
    list: The sibling processes, in unspecified (set) order.
    """
    siblings = set(N_1.get_close_observed_effect_set()) - {N_1.get_surrogate_A()}
    # Merge the close observed effects of every surrogate-offering process.
    siblings.update(
        *(p.get_who_offer_surrogate_A().get_close_observed_effect_set() for p in P_G)
    )
    # Finally drop the surrogates of the candidate causes themselves.
    siblings.difference_update(p.get_surrogate_A() for p in P_G)
    return list(siblings)




def _lag_column_names(process_name, first_lag, last_lag):
    """Return the column names `<process_name>_lag_<i>` for i in [first_lag, last_lag]."""
    return [process_name + "_lag_" + str(i) for i in range(first_lag, last_lag + 1)]


def satisfies_identifying_direct_causes(T_G, P_G, N_1, data, observed_var_set_data_T_G, NrOfConsideredLagTerms=50, significance_level=0.05):
    """
    Function for checking the condition of identifying direct causes.

    Parameters:
    T_G (list of instances of process): Augmented process set (not used directly here).
    P_G (list of instances of process): Selected subset of augmented process set
        (the candidate direct causes of N_1).
    N_1 (process): Selected component process.
    data: A pandas dataframe (not used directly here).
    observed_var_set_data_T_G: DataFrame whose columns are named
        `<process_name>_lag_<i>`; A and B are column subsets of it.
    NrOfConsideredLagTerms: the number of considered lag terms
    significance_level: is the significance_level for rank hypothesis test

    Returns:
    bool: True if the condition is satisfied, otherwise False.
        In case 2, the test is skipped when B has too few columns:
        - True if B has fewer columns than the expected rank;
        - False if B has exactly the expected rank.
    """
    observed_P_G_names = [p.get_process_name() for p in P_G if p.get_isObserved()]
    all_columns = list(observed_var_set_data_T_G.columns)

    # Case 1: N_1 is observed, and P_G contains only observed processes
    if N_1.get_isObserved() and len(observed_P_G_names) == len(P_G):
        target_name = N_1.get_process_name() + "_lag_0"

        # A = current value of N_1 plus all lags (1..L) of every candidate cause.
        var_setA_names = [target_name]
        for name in observed_P_G_names:
            var_setA_names.extend(_lag_column_names(name, 1, NrOfConsideredLagTerms))
        A = observed_var_set_data_T_G[var_setA_names]

        # B = every observed column except N_1's current value.
        var_setB_names = list(all_columns)
        var_setB_names.remove(target_name)
        B = observed_var_set_data_T_G[var_setB_names]

        expected_rank = NrOfConsideredLagTerms * len(observed_P_G_names)
        rank_test_result, p_values = CCARankTester.test_rank_equal_expected_phase1(A, B, expected_rank, significance_level)

    # Case 2: N_1 is latent, or at least one candidate cause is latent.
    else:
        N_1_surrogate_name = N_1.get_surrogate_A().get_process_name()

        print("\n")
        print("The N_1 process is", N_1.get_process_name(),"with its observed surrogate",N_1_surrogate_name,".")
        sib_process_set = [sib.get_process_name() for sib in N_1.get_close_observed_effect_set()]
        print("The sibling process set for  N_1 is",sib_process_set,".")
        print("The P_G process set is",[p.get_process_name() for p in P_G],".")
        for p in P_G:
            offerer = p.get_who_offer_surrogate_A()
            for sib in offerer.get_close_observed_effect_set():
                print("The sibling process set for a process of P_G",p.get_process_name(),"surrogate_A", str(offerer.get_process_name()), "is",str(sib.get_process_name()),".")

        # Current-time surrogates of N_1 and of the latent candidate causes.
        # They go into A and are excluded from B.
        excluded_names = [N_1_surrogate_name + "_lag_0"]
        for p in P_G:
            if not p.get_isObserved():
                excluded_names.append(p.get_surrogate_A().get_process_name() + "_lag_0")

        # Lags 1..L of the surrogates of N_1 and of every candidate cause.
        P_sur_names = _lag_column_names(N_1_surrogate_name, 1, NrOfConsideredLagTerms)
        for p in P_G:
            P_sur_names.extend(_lag_column_names(p.get_surrogate_A().get_process_name(), 1, NrOfConsideredLagTerms))

        # Lags 0..L of the known observed siblings. They are sorted by name
        # because get_known_sibling_variables returns set order, which varies
        # between runs.
        siblings = sorted(get_known_sibling_variables(N_1, P_G), key=lambda s: s.get_process_name())
        P_sib_names = []
        for sib in siblings:
            P_sib_names.extend(_lag_column_names(sib.get_surrogate_A().get_process_name(), 0, NrOfConsideredLagTerms))

        # Deduplicate while keeping insertion order, so the column order fed to
        # the rank test is reproducible. Plain set() order depends on string
        # hash randomization.
        var_setA_names = list(dict.fromkeys(excluded_names + P_sur_names + P_sib_names))
        A = observed_var_set_data_T_G[var_setA_names]

        excluded = set(excluded_names)
        var_setB_names = list(dict.fromkeys(c for c in all_columns if c not in excluded))
        B = observed_var_set_data_T_G[var_setB_names]

        expected_rank = len(var_setA_names) - 1
        print("expected_rank =",expected_rank, "length of A =", len(var_setA_names), "length of B =", len(var_setB_names))
        if len(var_setB_names) < expected_rank:  # In case that there are only two processes left.
            return True
        if len(var_setB_names) == expected_rank:  # In case that there is only one latent variable left.
            return False
        rank_test_result, p_values = CCARankTester.test_rank_equal_expected_phase1(A, B, expected_rank, significance_level)

    print("The p_values are",str(p_values), "when supposing",N_1.get_process_name()," has the direct causes",str([p.get_process_name() for p in P_G]),"with expected rank",str(expected_rank)+".")
    return rank_test_result     # True if the requirements are satisfied
        



# Special-case test for when N_1 has no direct causes.
def special_satisfies_identifying_direct_causes(T_G, N_1, data, observed_var_set_data_T_G, NrOfConsideredLagTerms=50, significance_level=0.05):
    """
    Check the identifying-direct-causes condition under the hypothesis that
    N_1 has no direct causes (the P_G = [] special case).

    A holds:
    - the current value of N_1's surrogate;
    - lags 0..L of the surrogates of N_1's known observed siblings.

    B holds every observed column except the current value of N_1's
    surrogate (and N_1's own current value, if that column exists).
    The rank test is run with expected rank 0.

    Parameters:
    T_G: Augmented process set (unused; kept for signature parity).
    N_1 (process): Selected component process.
    data: A pandas dataframe (unused; kept for signature parity).
    observed_var_set_data_T_G: DataFrame whose columns are named
        `<process_name>_lag_<i>`.
    NrOfConsideredLagTerms: the number of considered lag terms
    significance_level: is the significance_level for rank hypothesis test

    Returns:
    bool: True if the rank-0 hypothesis is accepted.
    """
    surrogate_name = N_1.get_surrogate_A().get_process_name()

    var_setA_names = [surrogate_name + "_lag_0"]
    for sib in get_known_sibling_variables(N_1, []):
        sib_name = sib.get_surrogate_A().get_process_name()
        var_setA_names.extend(sib_name + "_lag_" + str(i) for i in range(0, NrOfConsideredLagTerms + 1))
    A = observed_var_set_data_T_G[var_setA_names]

    # The variable placed in A (the surrogate's current value) must not also
    # be in B. The previous code removed only N_1's own "_lag_0" column. That
    # left the surrogate in B, and it presumably raises ValueError for a latent
    # N_1, whose own column is not among the observed columns.
    # N_1's own current value is still excluded when it exists (observed N_1).
    excluded = {surrogate_name + "_lag_0", N_1.get_process_name() + "_lag_0"}
    var_setB_names = [c for c in observed_var_set_data_T_G.columns if c not in excluded]
    B = observed_var_set_data_T_G[var_setB_names]

    print(A.shape, B.shape)
    expected_rank = 0
    rank_test_result, p_values = CCARankTester.test_rank_equal_expected_phase1(A, B, expected_rank, significance_level)

    print("The p_values are",str(p_values), "when supposing",N_1.get_process_name()," has no direct causes with expected rank",str(expected_rank)+".")
    return rank_test_result     # True if the requirements are satisfied
          









# In phase 1, allKnownCompProcesses does not need to change; it is passed through unchanged.
def _find_direct_cause_set(T_G, N_1, data, observed_var_set_data_T_G, maxNrOfDirectCauses, NrOfConsideredLagTerms, significance_level):
    """Return the first subset of T_G (smallest size first, at most maxNrOfDirectCauses) that passes the test for N_1, or None."""
    for size in range(1, min(len(T_G), maxNrOfDirectCauses) + 1):
        # itertools.combinations snapshots T_G into a tuple up front, so
        # iterating lazily is safe.
        for P_G in combinations(T_G, size):
            if satisfies_identifying_direct_causes(T_G, list(P_G), N_1, data, observed_var_set_data_T_G, NrOfConsideredLagTerms, significance_level):
                return P_G
    return None


def identifying_causal_relations(allKnownCompProcesses, A_G, T_G, data, observed_var_set_data_T_G, maxNrOfDirectCauses=3, NrOfConsideredLagTerms=50, significance_level=0.05):
    """
    Identifies causal relations and updates the causal graph and process sets.

    For each active process N_1, candidate cause sets P_G drawn from T_G are
    tried in order of increasing size, up to maxNrOfDirectCauses. For the
    first P_G that passes satisfies_identifying_direct_causes:
    - the edges P -> N_1 are recorded on the process objects;
    - N_1 is removed from A_G;
    - if N_1 is latent, it is also removed from T_G.
    The sweep over A_G repeats until a full pass makes no change.

    Parameters:
    allKnownCompProcesses (G) (list of instances of process): Being used for
        reconstructing the causal graph later (set of directed edges). Not
        modified here.
    A_G (set or list): Active process set; modified in place.
    T_G (set or list): Augmented process set; modified in place.
    data: A pandas dataframe, forwarded to the condition check.
    observed_var_set_data_T_G: Lagged observed data, forwarded to the condition check.
    maxNrOfDirectCauses: Largest candidate cause set size to try.
    NrOfConsideredLagTerms: the number of considered lag terms
    significance_level: is the significance_level for rank hypothesis test

    Returns:
    tuple: Updated (A_G, T_G, allKnownCompProcesses), in that order.
    """
    # NOTE: a special-case check for "N_1 has no direct causes"
    # (special_satisfies_identifying_direct_causes) is intentionally not run
    # here. The rank test is not stable for that case, and it may be replaced
    # by an independence test run beforehand.
    repeat = True  # Set whenever A_G shrinks, so another sweep is made.
    while repeat:
        repeat = False
        # Iterate over a snapshot of A_G so it can be modified in the loop.
        for N_1 in list(A_G):
            P_G = _find_direct_cause_set(T_G, N_1, data, observed_var_set_data_T_G, maxNrOfDirectCauses, NrOfConsideredLagTerms, significance_level)
            if P_G is None:
                continue

            # Update the causal graph with N_1's relationships (P -> N_1).
            for p in P_G:
                N_1.add_direct_cause(p)
                p.add_direct_effect(N_1)
                p.add_close_observed_effect_set(N_1)

            direct_causes = [p.get_process_name() for p in N_1.get_direct_causes()]
            print(N_1.get_process_name()+" has the direct causes "+str(direct_causes)+".")

            # N_1 is resolved: drop it from the active set, and from the
            # augmented set too if it is latent. Use the same getter as the
            # rest of the module.
            A_G.remove(N_1)
            if not N_1.get_isObserved():
                T_G.remove(N_1)
            repeat = True
    return A_G, T_G, allKnownCompProcesses



