
import numpy as np

class GaussianLogger:
    """Build summary statistics comparing predicted and adjusted probabilities.

    The statistics cover the full dataset, the control group, the treatment
    group and, when a subset is given, the subset, its complement within the
    treatment group, and the matching records of the control group.
    """

    def __init__(self):
        pass

    @staticmethod
    def _log_odds(probs):
        """Return the element-wise log-odds ln(p / (1 - p)) of *probs*."""
        return np.log(probs / (1 - probs))

    @staticmethod
    def _paired_averages(predicted, adjusted):
        """Return the four averages reported for a group of records.

        The tuple is (mean predicted, mean adjusted, mean of the
        differences, mean of the per-record log-odds differences).
        """
        return (
            np.mean(predicted),
            np.mean(adjusted),
            np.mean(np.subtract(predicted, adjusted)),
            np.mean(
                np.subtract(
                    GaussianLogger._log_odds(predicted),
                    GaussianLogger._log_odds(adjusted),
                )
            ),
        )

    @staticmethod
    def _subset_mask(frame, subset):
        """Return a positional boolean array of rows of *frame* in *subset*.

        A row matches when, for every column named in *subset*, its value is
        one of the values listed for that column.
        """
        matches = frame[list(subset)].isin(subset).all(axis=1)
        # A plain numpy mask indexes the probability arrays by position, so
        # the result does not depend on the frame's index labels.
        return np.asarray(matches, dtype=bool)

    def calculate_stats(self,subset, data_treatment, 
        data_treatment_predicted_probs, 
        data_treatment_equiv_probs,
        data_control, 
        data_control_predicted_probs, param, conditional_var, data_treatment_conditional_var,  add_scores = False, direction = None, include_conditional_var_base_rates =False):
        """Return a dict mapping human-readable labels to summary statistics.

        Parameters
        ----------
        subset : dict
            Maps column names to the list of accepted values. When empty
            (or None) only the full/control/treatment statistics are returned.
        data_treatment, data_control
            Treatment and control records. Only their length is used unless
            a subset is given, in which case they are indexed by column and
            filtered with ``isin`` (i.e. expected to be pandas DataFrames).
        data_treatment_predicted_probs, data_treatment_equiv_probs
            Predicted and adjusted probabilities, one per treatment record,
            in the same order as the rows of ``data_treatment``.
        data_control_predicted_probs
            Predicted probabilities, one per control record.
        param, conditional_var, data_treatment_conditional_var, add_scores,
        direction, include_conditional_var_base_rates
            Accepted for interface compatibility; not used by this method.

        Returns
        -------
        dict
            Insertion-ordered mapping of label -> value.
        """
        treatment_predicted = np.asarray(data_treatment_predicted_probs, dtype=float)
        treatment_adjusted = np.asarray(data_treatment_equiv_probs, dtype=float)
        control_predicted = np.asarray(data_control_predicted_probs, dtype=float)

        treatment_n = len(data_treatment)
        control_n = len(data_control)
        full_n = treatment_n + control_n

        control_average_predicted_probs = np.mean(control_predicted)
        (
            treatment_average_predicted_probs,
            treatment_average_adj_probs,
            treatment_probs_diff_average,
            treatment_delta_average,
        ) = self._paired_averages(treatment_predicted, treatment_adjusted)

        # Size-weighted combination of the two group means.
        full_average_predicted_probs = (
            treatment_average_predicted_probs * len(treatment_predicted)
            + control_average_predicted_probs * len(control_predicted)
        ) / full_n

        has_subset = bool(subset)

        # NOTE: the label historically has a doubled space after "events"
        # when a subset is reported. Both spellings are kept unchanged so
        # existing key lookups by callers keep working.
        if has_subset:
            treatment_events_key = "[treatment group] Average of events  (Pr(event|..))"
        else:
            treatment_events_key = "[treatment group] Average of events (Pr(event|..))"

        stats_dict = {
            "[full dataset] Number of records": full_n,
            "[full dataset] Average of events (Pr(event|..))": full_average_predicted_probs,
            "[control group] Number of records": control_n,
            "[control group] Average of events (Pr(event|..))": control_average_predicted_probs,
            "[treatment group] Number of records": treatment_n,
            treatment_events_key: treatment_average_predicted_probs,
            "[treatment group] Average of adjusted probabilities (\\hat{p})": treatment_average_adj_probs,
            "[treatment group] Average of differences of predicted and adjusted probabilities (average of \\tilde{p} - \\hat{p})": treatment_probs_diff_average,
            "[treatment group] Average of deltas (average of (ln(\\tilde{p}/(1-\\tilde{p}))) - ln(\\hat{p}/(1-\\hat{p})))": treatment_delta_average,
        }

        if not has_subset:
            return stats_dict

        # Records of the treatment group inside the subset.
        in_subset = self._subset_mask(data_treatment, subset)
        subset_n = int(in_subset.sum())
        (
            subset_average_predicted_probs,
            subset_average_adj_probs,
            subset_probs_diff_average,
            subset_delta_average,
        ) = self._paired_averages(
            treatment_predicted[in_subset], treatment_adjusted[in_subset]
        )

        # Records of the treatment group outside the subset. The delta is the
        # mean of per-record log-odds differences, exactly as for the other
        # groups (not the log-odds difference of the two mean probabilities).
        out_of_subset = ~in_subset
        complement_subset_n = int(out_of_subset.sum())
        (
            complement_average_predicted_probs,
            complement_average_adj_probs,
            complement_subset_probs_diff_average,
            complement_subset_delta_average,
        ) = self._paired_averages(
            treatment_predicted[out_of_subset], treatment_adjusted[out_of_subset]
        )

        # Records of the control group matching the same subset definition.
        in_subset_control = self._subset_mask(data_control, subset)
        subset_control_n = int(in_subset_control.sum())
        subset_control_average_predicted_probs = np.mean(
            control_predicted[in_subset_control]
        )

        stats_dict.update({
            "[found subset in treatment group] Number of records": subset_n,
            "[found subset in treatment group] Average of events  (Pr(event|..))": subset_average_predicted_probs,
            "[found subset in treatment group] Average of adjusted probabilities (\\hat{p})": subset_average_adj_probs,
            "[found subset in treatment group] Average of differences of predicted and adjusted probabilities (\\tilde{p} - \\hat{p})": subset_probs_diff_average,
            "[found subset by scan] Average of deltas (average of (ln(\\tilde{p}/(1-\\tilde{p}))) - ln(\\hat{p}/(1-\\hat{p})))": subset_delta_average,
            "[complement of found subset in treatment group] Number of records": complement_subset_n,
            "[complement of found subset in treatment group] Average of events  (Pr(event|..))": complement_average_predicted_probs,
            "[complement of found subset in treatment group] Average of adjusted probabilities (\\hat{p})": complement_average_adj_probs,
            "[complement of found subset in treatment group] Average of differences of predicted and adjusted probabilities (\\tilde{p} - \\hat{p})": complement_subset_probs_diff_average,
            "[complement of found subset in treatment group] Average of deltas (average of (ln(\\tilde{p}/(1-\\tilde{p}))) - ln(\\hat{p}/(1-\\hat{p})))": complement_subset_delta_average,
            "[found subset in control group]  Number of records": subset_control_n,
            "[found subset in control group]  Average of events (Pr(event|..))": subset_control_average_predicted_probs,
        })

        return stats_dict