
import numpy as np
import time
import pandas as pd
import sys

from cbs_preprocessor import CBSPreProcessor
from fsscan import FSScan

import yaml
import copy

class ConditionalBiasScan:
    """Validate settings, preprocess the data and run a conditional bias scan.

    The class is a thin orchestrator: it stores the scan configuration,
    checks that the configuration and the event column are usable for the
    requested scan type, delegates preprocessing to ``CBSPreProcessor`` and
    the scan itself to ``FSScan``.

    Attributes set in ``__init__`` mirror the constructor arguments, except:

    * ``feature_list`` / ``scan_feature_list`` are private deep copies with
      ``protected_class`` removed (if it was present).
    * ``p_bin_var`` is the hard-coded name (``"p_bin_var"``) of the column
      that receives the thresholded scores when ``threshold_bool`` is set.
    """

    def __init__(self,  protected_class, protected_value, combo_scan,event_var, conditional_variable, scan_params, direction, feature_list, scan_type, scan_feature_list, threshold_bool, theshold_cutoff, printing_on = True):
        """Store the scan configuration.

        Parameters
        ----------
        protected_class, protected_value, combo_scan, event_var,
        conditional_variable :
            Stored as-is and forwarded to ``CBSPreProcessor`` in ``run``.
            ``combo_scan`` must equal ``"both"`` for the
            ``"prediction_sufficiency"`` scan type.
        scan_params :
            Mapping expanded as keyword arguments into
            ``FSScan.evaluate_scan``.
        direction :
            Forwarded positionally to ``FSScan.evaluate_scan``.
        feature_list, scan_feature_list :
            Lists of feature names. They are deep-copied, so the caller's
            lists are never modified; ``protected_class`` is dropped from
            each copy when present.
        scan_type :
            One of ``"binary_separation"``, ``"binary_sufficiency"``,
            ``"prediction_separation"``, ``"prediction_sufficiency"``
            (checked in ``run``, not here).
        threshold_bool :
            When truthy, ``run`` adds a 0/1 column derived from the score.
        theshold_cutoff :
            Scores ``>=`` this value become 1 in that column. (The parameter
            name keeps its original spelling for compatibility.)
        printing_on :
            Forwarded to ``FSScan`` and used to gate this class's own
            informational output.
        """
        self.combo_scan = combo_scan
        self.protected_class = protected_class
        self.protected_value = protected_value
        self.printing_on = printing_on

        self.event_var = event_var
        self.conditional_variable = conditional_variable

        # Work on private copies so the caller's lists are left untouched.
        self.feature_list = copy.deepcopy(feature_list)
        self.scan_feature_list = copy.deepcopy(scan_feature_list)

        # The protected class must not appear among the features. Both lists
        # are guarded the same way so that a list which already omits the
        # protected class does not raise ValueError from list.remove().
        if self.protected_class in self.feature_list:
            self.feature_list.remove(self.protected_class)
        if self.protected_class in self.scan_feature_list:
            self.scan_feature_list.remove(self.protected_class)

        self.scan_type = scan_type

        # NOTE: the name of the thresholded-prediction column is hard coded;
        # a way to configure it still needs to be found.
        self.p_bin_var = "p_bin_var"

        self.scan_params = scan_params
        self.direction = direction

        # Threshold information used by run().
        self.threshold_bool = threshold_bool
        self.theshold_cutoff = theshold_cutoff

    def check_scan_type_combination_param_compatibility(self):
        """Exit the process if ``scan_type`` / ``combo_scan`` are incompatible.

        Calls ``sys.exit`` with an explanatory message (i.e. raises
        ``SystemExit``) when ``scan_type`` is not one of the known scan
        types, or when ``"prediction_sufficiency"`` is requested without
        ``combo_scan == "both"``. Returns ``None`` otherwise.
        """
        # NOTE: the list of valid scan types might be better defined elsewhere.
        scan_types = ["binary_separation", "binary_sufficiency", "prediction_separation", "prediction_sufficiency"]
        if self.scan_type not in scan_types:
            message = self.scan_type+ " is not a scan type. " + ", ".join(scan_types)+ " are the avaiable types."
            sys.exit(message)

        if (self.scan_type == "prediction_sufficiency") and (self.combo_scan != "both"):
            message = "the prediction sufficiency scan only runs as a combination scan"
            sys.exit(message)

    def check_data_types_scan_types(self):
        """Exit the process if the event column is not strictly 0/1.

        Every scan type except ``"prediction_separation"`` requires each
        value of ``self.data[self.event_var]`` to equal 0 or 1. On the first
        violation the full column is printed and ``sys.exit`` is called with
        an explanatory message. Requires ``self.data`` to be set (``run``
        does this before calling).
        """
        if self.scan_type == "prediction_separation":
            # This scan type places no binary restriction on the event values.
            return

        event_values = list(self.data[self.event_var])
        if any(e not in (1, 0) for e in event_values):
            message = "scan type is " + self.scan_type + " and requires the event variable to reference a column with only binary values (0 or 1)"
            print(event_values)
            sys.exit(message)

    def run(self, dataset_configs, data, observed_label, score, log_reg_hyperparams = None):
        """Run validation, preprocessing and the scan; return the scan stats.

        Parameters
        ----------
        dataset_configs :
            Mapping that must contain the key ``"observed_label"``. The whole
            object is attached to the result under ``"dataset_yaml"``.
        data :
            Tabular data indexed by column name (presumably a pandas
            DataFrame -- confirm against callers). It is deep-copied, so the
            caller's object is not modified.
        observed_label :
            Accepted for interface compatibility but not read; the observed
            label is taken from ``dataset_configs["observed_label"]``.
        score :
            Name of the column in ``data`` holding the scores.
        log_reg_hyperparams :
            Forwarded unchanged to ``CBSPreProcessor.run``.

        Returns
        -------
        ``None`` when the preprocessor returns ``None`` as treatment data;
        otherwise the object returned by ``FSScan.evaluate_scan`` with the
        keys ``"dataset_yaml"``, ``"p_hat_coefficient_mapping"``,
        ``"control_conditional_var"``, ``"treatment_conditional_var"`` and
        ``"control_weights"`` added.

        Raises
        ------
        SystemExit
            If the configuration or the event column fails validation.
        KeyError
            If ``dataset_configs`` lacks ``"observed_label"``.
        """
        # Check the parameters before doing any work.
        self.check_scan_type_combination_param_compatibility()

        self.data = copy.deepcopy(data)
        self.observed_label = dataset_configs["observed_label"]
        self.score = score

        # 0. Threshold the scores into a 0/1 column if requested.
        if self.threshold_bool:
            if self.printing_on:
                print(self.score)
            self.data[self.p_bin_var] = np.where(self.data[self.score] >= self.theshold_cutoff, 1, 0).tolist()

        # 1. Make sure scans that need binary event values have them. This
        # runs after thresholding so a freshly created column is also checked
        # when it is the event variable.
        self.check_data_types_scan_types()

        # 2. Run preprocessing.
        self.cbs_preprocessor = CBSPreProcessor(self.data, self.protected_class, self.protected_value, self.combo_scan, self.score , self.event_var, self.conditional_variable, self.feature_list, self.scan_type, self.scan_feature_list)

        (
            data_treatment,
            data_treatment_predicted_probs,
            data_treatment_equiv_probs,
            data_treatment_conditional_var,
            data_control,
            data_control_predicted_probs,
            data_control_conditional_var,
            data_control_weights,
            coefficient_mapping,
        ) = self.cbs_preprocessor.run(log_reg_hyperparams)

        # The preprocessor signals "nothing to scan" with a None treatment
        # set. An identity test is used because `==` on tabular data may be
        # element-wise.
        if data_treatment is None:
            return None

        # Keep the preprocessor's version of the data.
        self.data = self.cbs_preprocessor.data

        # 3. Run the scan.
        self.fs_scan = FSScan(self.printing_on)
        stats = self.fs_scan.evaluate_scan(
            data_treatment,
            data_treatment_predicted_probs,
            data_treatment_equiv_probs,
            data_control,
            data_control_predicted_probs,
            self.scan_type,
            self.direction,
            **self.scan_params
        )

        # Attach context needed to interpret the scan results.
        stats["dataset_yaml"] = dataset_configs
        stats["p_hat_coefficient_mapping"] = coefficient_mapping
        stats["control_conditional_var"] = data_control_conditional_var
        stats["treatment_conditional_var"] = data_treatment_conditional_var
        stats["control_weights"] = data_control_weights
        return stats
        


