import numpy as np
import pandas as pd
import warnings
import Utils
from itertools import combinations
from ComponentProcess import ComponentProcess
import CCARankTester


    
def satisfies_latent_common_cause(T_G, subset, data, observed_var_set_data_T_G, NrOfConsideredLagTerms=50, significance_level=0.05):
    """
    Check whether a subset of processes satisfies the latent common cause condition.

    Two column sets are selected from observed_var_set_data_T_G and handed to
    CCARankTester.test_rank_equal_expected_phase2 together with an expected rank:

    * A: lags 0..NrOfConsideredLagTerms of the surrogate of every process in
      `subset`, followed by the same lags of every "sibling" (a member of some
      subset process's close observed effect set that is not itself a surrogate
      of a subset process).
    * B: every column of observed_var_set_data_T_G except the lag-0 column of
      each subset surrogate.

    Parameters:
    T_G (list of ComponentProcess): Augmented process set. Not used by this function;
        kept so existing callers keep working.
    subset (iterable of ComponentProcess): Subset of active processes being tested.
    data (pd.DataFrame): Time-series data for the processes. Not used by this function;
        kept so existing callers keep working.
    observed_var_set_data_T_G (pd.DataFrame): Table whose columns are named
        "<process name>_lag_<i>"; it must contain every column selected for A.
    NrOfConsideredLagTerms (int): Number of lag terms to consider for testing.
    significance_level (float): Significance level for statistical tests.

    Returns:
    bool: True if the subset satisfies the latent common cause condition, False otherwise.
        When B has fewer columns than the expected rank the function returns True,
        and when B has exactly as many it returns False, both without running the
        rank test; otherwise the result of the rank test is returned.

    Raises:
    KeyError: If a column required for A is missing from observed_var_set_data_T_G.
    """
    subset = list(subset)   # the subset is traversed several times below
    lags = range(NrOfConsideredLagTerms + 1)

    surrogate_names = [p.get_surrogate_A().get_process_name() for p in subset]
    subset_lag_names = [name + "_lag_" + str(i) for name in surrogate_names for i in lags]

    # Collect every close observed effect first and only then drop the surrogates of
    # the subset. Interleaving update/discard per process made the result depend on
    # the order of `subset`: a surrogate discarded for an earlier process could be
    # re-added by a later one, duplicating its columns in A and inflating the
    # expected rank.
    sibling_set = set()
    for p in subset:
        sibling_set.update(p.get_close_observed_effect_set())
    for p in subset:
        sibling_set.discard(p.get_surrogate_A())

    # Sort by name so the column order of A does not depend on set iteration order.
    siblings = sorted(sibling_set, key=lambda s: s.get_surrogate_A().get_process_name())

    # Original author's note: it doesn't matter whether p.get_surrogate_A() or just p
    # is used for siblings, since all siblings are observed processes.
    P_sib_names = [
        s.get_surrogate_A().get_process_name() + "_lag_" + str(i)
        for s in siblings
        for i in lags
    ]

    var_setA_names = subset_lag_names + P_sib_names
    A = observed_var_set_data_T_G[var_setA_names]

    # B keeps the DataFrame's own column order (de-duplicated) instead of the
    # arbitrary order produced by a set difference.
    excluded = {name + "_lag_0" for name in surrogate_names}
    var_setB_names = list(dict.fromkeys(
        c for c in observed_var_set_data_T_G.columns if c not in excluded
    ))
    B = observed_var_set_data_T_G[var_setB_names]

    # With a two-process subset this is one less than the number of columns in A.
    expected_rank = len(P_sib_names) + 2*NrOfConsideredLagTerms + 1
    print("expected_rank =",expected_rank, "length of A =", len(var_setA_names), "length of B =", len(var_setB_names))

    if len(var_setB_names) < expected_rank:
        # Original author's note: the case where only two processes are left.
        return True
    if len(var_setB_names) == expected_rank:
        # Original author's note: the case where only one latent variable is left.
        return False

    rank_test_result, p_values = CCARankTester.test_rank_equal_expected_phase2(A, B, expected_rank, significance_level)

    print("The p_values are",str(p_values), "when supposing",str([p.get_process_name() for p in subset])," has a common latent direct cause with expected rank",str(expected_rank)+".")
    return rank_test_result     # True if the requirements are satisfied
        

  
  
    
    
    



def discovering_new_latent_component_processes(allKnownCompProcesses, A_G, T_G, data, observed_var_set_data_T_G, currentNrOfLatentProcesses, NrOfConsideredLagTerms=50, significance_level=0.05):
    """
    Introduce new latent component processes for clusters of active processes.

    Every pair of active processes is tested with satisfies_latent_common_cause.
    The accepted pairs are passed to Utils.merge_list_modified (presumably merging
    pairs that share a process into larger clusters; confirm in Utils), and one new
    latent ComponentProcess named "L_<k>" is created per resulting cluster.

    The three process collections are modified in place and also returned.

    Parameters:
    allKnownCompProcesses (list of ComponentProcess): All known processes (G); each new
        latent process is appended to it.
    A_G (list of ComponentProcess): Active process set. Cluster members are removed
        and the new latent process is appended.
    T_G (list of ComponentProcess): Augmented process set. Non-observed cluster members
        are removed and the new latent process is appended.
    data (pd.DataFrame): Forwarded unchanged to satisfies_latent_common_cause.
    observed_var_set_data_T_G (pd.DataFrame): Forwarded unchanged to
        satisfies_latent_common_cause.
    currentNrOfLatentProcesses (int): Number of latent processes created so far; used
        to number the new ones.
    NrOfConsideredLagTerms (int): Forwarded to satisfies_latent_common_cause.
    significance_level (float): Forwarded to satisfies_latent_common_cause.

    Returns:
    tuple: (A_G, T_G, allKnownCompProcesses, currentNrOfLatentProcesses), where the
        counter has been increased by the number of latent processes created.
    """
    pair_size = 2

    # Keep every pair of active processes that passes the latent-common-cause test.
    accepted_pairs = [
        list(pair)
        for pair in combinations(A_G, pair_size)
        if satisfies_latent_common_cause(T_G, pair, data, observed_var_set_data_T_G, NrOfConsideredLagTerms, significance_level)
    ]

    # One new latent process per merged cluster.
    for members in Utils.merge_list_modified(accepted_pairs):
        members = list(members)

        # The first two members of the cluster supply the surrogates of the new process.
        next_index = currentNrOfLatentProcesses + 1
        latent = ComponentProcess(
            process_name=f"L_{next_index}",
            isObserved=False,
            surrogate_A=members[0].get_surrogate_A(),
            surrogate_B=members[1].get_surrogate_A(),
            close_observed_effect_set=set(),
            who_offer_surrogate_A=members[0],
        )
        currentNrOfLatentProcesses = next_index

        # Wire the causal edges latent -> member in both directions of bookkeeping.
        for child in members:
            latent.add_direct_effect(child)
            latent.add_close_observed_effect_set(child)
            child.add_direct_cause(latent)

        effect_names = [effect.get_process_name() for effect in latent.get_direct_effects()]
        print(str(effect_names)+" have the direct cause "+latent.get_process_name()+".")

        # The cluster members stop being active; latent members also leave T_G.
        for child in members:
            A_G.remove(child)
            if not child.isObserved:
                T_G.remove(child)

        # The new latent process joins all three collections.
        for registry in (A_G, T_G, allKnownCompProcesses):
            registry.append(latent)

    return A_G, T_G, allKnownCompProcesses, currentNrOfLatentProcesses
   
 




