import datetime    
import os
import pandas as pd
import numpy as np

from distribution_specific_logger_functions.gaussian_specific_logger_functions import GaussianLogger
from distribution_specific_logger_functions.bernoulli_specific_logger_functions import BernoulliLogger

class CBSLogger:
    """Assemble per-subset statistics and scan metadata for a CBS run.

    Each call to :meth:`write_results` computes distribution-specific statistics
    for a found subset and builds one flat record describing the scan. The
    record is stored on ``self.last_record``. Persisting it to
    ``<results_folder>/<start_time>.csv`` is currently disabled.
    """

    def __init__(self, results_folder):
        """Create a logger whose results CSV path lives under ``results_folder``.

        Args:
            results_folder: Directory for the results CSV (str or path-like).
                It is not created or checked here.
        """
        self.results_folder = results_folder

        # Minute-resolution timestamp. It is used both as a record field and as
        # the CSV file stem, so two loggers created within the same minute map
        # to the same file name.
        self.start_time = datetime.datetime.now().strftime("_%m_%d_%Y_%H%M")

        # Per-run settings, populated by set_for_run(). They are initialised
        # here so the attributes always exist instead of raising AttributeError
        # before the first set_for_run() call.
        self.protected_class = None
        self.protected_value = None
        self.other_params = None

        # Most recent full record built by write_results() (None until then).
        self.last_record = None

    def checking_for_file(self):
        """Return ``(exists, csv_path)`` for this run's results CSV.

        The path is ``<results_folder>/<start_time>.csv``. ``os.path.join`` is
        used for three reasons:

        * A trailing separator on ``results_folder`` does not produce ``//``.
        * An empty ``results_folder`` resolves relative to the current
          directory instead of the filesystem root.
        * Path-like folders are accepted.
        """
        csv_path = os.path.join(self.results_folder, self.start_time + ".csv")
        return (os.path.isfile(csv_path), csv_path)

    def set_for_run(self, protected_class, protected_value, other_params):
        """Record the protected attribute/value and extra parameters of a run."""
        self.protected_class = protected_class
        self.protected_value = protected_value
        self.other_params = other_params

    def simplify_subset(self, data, subset):
        """Drop subset fields that do not restrict ``data``.

        A field is dropped when its listed values are exactly the set of unique
        values that column takes in ``data``, which keeps the logged subset
        sparse.

        Args:
            data: DataFrame containing every column named in ``subset``.
            subset: Mapping of column name -> iterable of selected values.

        Returns:
            A new dict containing only the restricting fields, in their
            original order.

        Raises:
            KeyError: if a key of ``subset`` is not a column of ``data``.
        """
        return {
            column: subset[column]
            for column in subset.keys()
            if set(subset[column]) != set(data[column].unique())
        }

    def write_results(self, subset, score, param, data_treatment,
        data_treatment_predicted_probs,
        data_treatment_equiv_probs,
        data_control,
        data_control_predicted_probs,
        data_control_conditional_var,
        data_treatment_conditional_var,
        data_set_specific_yaml,
        protected_class,
        protected_value,
        combo_scan,
        event ,
        conditional_variable,
        params,
        direction,
        feature_list,
        scan_type,
        scan_feature_list, append_to_dictionary_names = "", add_scores = False, include_conditional_var_base_rates = False):
        """Compute statistics for a found subset and build its result record.

        ``GaussianLogger`` is used when ``scan_type == "prediction_separation"``.
        Otherwise ``BernoulliLogger`` is used. Both are given the subset, the
        treatment/control data and their predicted probabilities, ``param``,
        the conditional-variable data, ``add_scores``, ``direction`` and
        ``include_conditional_var_base_rates``.

        Scan metadata and the (suffixed) statistics are merged into one flat
        record, which is stored on ``self.last_record``. ``self.isFile`` and
        ``self.csv_path`` are refreshed from :meth:`checking_for_file`. No file
        is written.

        ``data_set_specific_yaml`` is accepted for interface compatibility but
        is not used.

        Returns:
            The statistics dict returned by ``calculate_stats``, with
            ``append_to_dictionary_names`` appended to every key.
        """
        # Make the logged subset sparse: fields that list every value present
        # in the treatment data are dropped.
        simplified_subset = self.simplify_subset(data_treatment, subset)

        initial_record = {"start time of scan": str(self.start_time),
                          "protected class": protected_class,
                          "protected class value": protected_value,
                          "combination scan": combo_scan,
                          "event  (Pr(event|..))": event,
                          "conditional variable  (Pr(..|conditional variable))": conditional_variable,
                          "scan type": scan_type,
                          "direction": direction,
                          "other params for scan": str(params),
                          "found subset": str(simplified_subset),
                          "subset score": str(score),
                          "param in found subset's alternative hypothesis": str(param),
                          "feature_list": str(feature_list),
                          "scan_attribute_list": str(scan_feature_list)}

        stats_generator = GaussianLogger() if scan_type == "prediction_separation" else BernoulliLogger()

        stats_dict = stats_generator.calculate_stats(subset, data_treatment, data_treatment_predicted_probs,
                                                     data_treatment_equiv_probs, data_control,
                                                     data_control_predicted_probs, param,
                                                     data_control_conditional_var, data_treatment_conditional_var,
                                                     add_scores, direction, include_conditional_var_base_rates)

        # Suffix every statistic name so stats from different calls (e.g. of
        # different scan variants) can be merged without key collisions.
        stats_dict = {key + append_to_dictionary_names: stats_dict[key] for key in stats_dict}

        # Full row describing this result: scan metadata followed by stats.
        # It was previously built and discarded; keep it reachable for callers.
        self.last_record = {**initial_record, **stats_dict}

        self.isFile, self.csv_path = self.checking_for_file()

        # NOTE: persisting self.last_record to self.csv_path was disabled in
        # the original code. The disabled version read any existing CSV,
        # appended the record and rewrote the file, so differing stat columns
        # were unioned. Re-enable deliberately if CSV output is wanted.

        return stats_dict
            
        
        
        
        
    
    
        
    
    