import numpy as np

import os
import sys
sys.path.append("..")
sys.path.append("")

from scan_distribution_specific_score_functions.bernoulli_scan import BernoulliScan

class BernoulliLogger:
    """Summarise a subset found by a Bernoulli scan.

    ``calculate_stats`` returns a dictionary mapping human-readable labels to
    record counts and averages for the full data, the control group, the
    treatment group, the found subset (in both groups) and the complement of
    the subset within the treatment group.
    """

    def __init__(self):
        pass

    @staticmethod
    def _subset_mask(data, subset):
        """Return a boolean row mask of *data*: True where every column named in
        *subset* holds one of the values listed for it."""
        return data[list(subset)].isin(subset).all(axis=1)

    @staticmethod
    def _mean_odds(probs):
        """Return the mean of the element-wise odds ``p / (1 - p)`` of *probs*."""
        return np.mean(probs / (1 - probs))

    def calculate_stats(self, subset, data_treatment,
                        data_treatment_predicted_probs,
                        data_treatment_equiv_probs,
                        data_control,
                        data_control_predicted_probs, q,
                        data_control_conditional_var,
                        data_treatment_conditional_var,
                        add_scores=False, direction=None,
                        include_conditional_var_base_rates=False):
        """Compute descriptive statistics for *subset* in the treatment and control groups.

        Args:
            subset: Mapping of column name -> collection of allowed values
                (passed to ``DataFrame.isin``). A row belongs to the subset when
                every listed column holds one of its allowed values. An empty
                (falsy) mapping returns only the group-level statistics.
            data_treatment: Treatment-group records; must contain the subset's
                columns.
            data_treatment_predicted_probs: Per-row values for the treatment
                group, reported as "Average of events". Indexed with a boolean
                row mask of *data_treatment*.
            data_treatment_equiv_probs: Per-row adjusted probabilities p_hat for
                the treatment group. Their odds p_hat / (1 - p_hat) form the H_0
                statistics.
            data_control: Control-group records; must contain the subset's
                columns.
            data_control_predicted_probs: Per-row values for the control group.
            q: Odds multiplier used for the H_1 statistics (q * odds). When it is
                not ``np.inf`` and *add_scores* is true, it is also passed to the
                scan as ``q_given``.
            data_control_conditional_var: Per-row conditional variable for the
                control group.
            data_treatment_conditional_var: Per-row conditional variable for the
                treatment group. The base-rate section compares these to 1 and 0,
                so they are presumably binary (confirm with callers).
            add_scores: When true, re-score the subset with
                ``BernoulliScan.score_current_subset`` (penalty 0) and add the
                ``cbs_*`` entries.
            direction: Forwarded unchanged to ``score_current_subset``.
            include_conditional_var_base_rates: When true, add averages split by
                conditional variable == 1 and == 0.

        Returns:
            dict of label -> statistic, in insertion order: full dataset, control
            group and treatment group entries. For a non-empty subset these are
            followed by the subset, complement and control-subset entries, then
            the optional score and base-rate entries.
        """
        full_n = len(data_treatment) + len(data_control)
        # Sum of both groups over the total count. This is the same weighted
        # average as mean_t * n_t + mean_c * n_c, but it stays finite when one
        # group is empty (the mean of an empty array is nan, and nan * 0 is nan).
        full_average_predicted_probs = (
            np.sum(data_treatment_predicted_probs)
            + np.sum(data_control_predicted_probs)
        ) / full_n

        stats_dict = {
            "[full dataset] Number of records": full_n,
            "[full dataset] Average of events (Pr(event|..))": full_average_predicted_probs,

            "[control group] Number of records": len(data_control),
            "[control group] Average of events (Pr(event|..))": np.mean(data_control_predicted_probs),

            "[treatment group] Number of records": len(data_treatment),
            "[treatment group] Average of events (Pr(event|..))": np.mean(data_treatment_predicted_probs),
            r"[treatment group] Average of adjusted probabilities (\hat{p})": np.mean(data_treatment_equiv_probs),
            # Mean of the per-record odds: the same H_0 statistic that is
            # reported for the subset and its complement below.
            r"[treatment group] Average of H_0 (\hat{p}/(1-\hat{p}))": self._mean_odds(data_treatment_equiv_probs),
        }

        if not subset:
            return stats_dict

        # Row masks for the subset within each group.
        treatment_mask = self._subset_mask(data_treatment, subset)
        complement_mask = ~treatment_mask
        control_mask = self._subset_mask(data_control, subset)

        # Subset within the treatment group.
        subset_predicted_probs = data_treatment_predicted_probs[treatment_mask]
        subset_adj_probs = data_treatment_equiv_probs[treatment_mask]
        subset_conditional_var = data_treatment_conditional_var[treatment_mask]
        subset_odds = subset_adj_probs / (1 - subset_adj_probs)

        # Complement of the subset within the treatment group. The odds are
        # averaged per record from the array, not derived from the mean p_hat.
        complement_predicted_probs = data_treatment_predicted_probs[complement_mask]
        complement_adj_probs = data_treatment_equiv_probs[complement_mask]

        # Subset within the control group.
        subset_control_predicted_probs = data_control_predicted_probs[control_mask]
        subset_control_conditional_var = data_control_conditional_var[control_mask]

        stats_dict.update({
            "[found subset in treatment group] Number of records": int(treatment_mask.sum()),
            "[found subset in treatment group] Average of events (Pr(event|..))": np.mean(subset_predicted_probs),
            r"[found subset in treatment group] Average of adjusted probabilities (\hat{p})": np.mean(subset_adj_probs),
            r"[found subset in treatment group] Average of H_0 (\hat{p}/(1-\hat{p}))": np.mean(subset_odds),
            r"[found subset in treatment group] Average of H_1 (q*(\hat{p}/(1-\hat{p})))": np.mean(q * subset_odds),
            # (q * odds) / odds equals q where the odds are finite and non-zero.
            # The expression is left as-is so records with odds of 0 or inf
            # still produce nan, exactly as before.
            "[found subset in treatment group] Average of log-likelihoods (ln(H_1 / H_0))": np.mean(np.log((q * subset_odds) / subset_odds)),
            "[found subset in treatment group]  Average of conditional variable (Pr(..|conditional_var))": np.mean(subset_conditional_var),

            "[complement of found subset in treatment group] Number of records": int(complement_mask.sum()),
            "[complement of found subset in treatment group] Average of events (Pr(event|..))": np.mean(complement_predicted_probs),
            r"[complement of found subset in treatment group] Average of adjusted probabilities (\hat{p})": np.mean(complement_adj_probs),
            r"[complement of found subset in treatment group] Average of H_0 (\hat{p}/(1-\hat{p}))": self._mean_odds(complement_adj_probs),

            "[found subset in control group]  Number of records": int(control_mask.sum()),
            "[found subset in control group]  Average of events (Pr(event|..))": np.mean(subset_control_predicted_probs),
            "[found subset in control group]  Average of conditional variable (Pr(..|conditional_var))": np.mean(subset_control_conditional_var),
        })

        if add_scores:
            bernoulli_scan = BernoulliScan()

            # Re-score the subset with zero penalty, without supplying q.
            pen_score, q_mle = bernoulli_scan.score_current_subset(
                data_treatment, data_treatment_equiv_probs,
                data_treatment_predicted_probs, 0, subset, direction=direction)
            stats_dict["cbs_score_recalculated"] = pen_score
            stats_dict["cbs_param_recalculated"] = q_mle

            if q != np.inf:
                # Score again at the caller-supplied q, for comparison with the
                # recalculated entries above.
                pen_score, q_mle = bernoulli_scan.score_current_subset(
                    data_treatment, data_treatment_equiv_probs,
                    data_treatment_predicted_probs, 0, subset,
                    direction=direction, q_given=q)
                stats_dict["cbs_score_validate"] = pen_score
                stats_dict["cbs_param_validate"] = q_mle

        if include_conditional_var_base_rates:
            treated_c1 = subset_conditional_var == 1
            treated_c0 = subset_conditional_var == 0
            stats_dict["[found subset in treatment group]  Average of Pr(Y=1|C=1)"] = np.mean(subset_predicted_probs[treated_c1])
            stats_dict["[found subset in treatment group]  Average of Pr(Y=1|C=0)"] = np.mean(subset_predicted_probs[treated_c0])
            stats_dict["[found subset in treatment group]  Average of p_hat|C=1"] = np.mean(subset_adj_probs[treated_c1])
            stats_dict["[found subset in treatment group]  Average of p_hat|C=0"] = np.mean(subset_adj_probs[treated_c0])

            stats_dict["[found subset in control group]  Average of Pr(Y=1|C=1)"] = np.mean(
                subset_control_predicted_probs[subset_control_conditional_var == 1])
            stats_dict["[found subset in control group]  Average of Pr(Y=1|C=0)"] = np.mean(
                subset_control_predicted_probs[subset_control_conditional_var == 0])

        return stats_dict