import numpy as np
import pandas as pd
from sklearn import linear_model
from scipy.stats import pearsonr

class CBSPreProcessor:
    """Prepare control/treatment data sets for a conditional bias scan.

    The pipeline implemented by :meth:`run` is:

    1. Fit a logistic regression that predicts membership of the protected
       class from ``feature_list`` and turn the predicted probabilities into
       odds weights (:meth:`produce_inverse_propensity_score`).
    2. Optionally keep only the rows with a given value of the conditional
       variable (:meth:`filter_by_outcome`).
    3. Fit a weighted logistic regression on the non-protected rows and use
       it to predict ``p_hat`` for the protected rows (:meth:`produce_p_hat`).

    Attributes set in ``__init__``:
        data_original: deep copy of the input frame, taken before any mutation.
        data: the input frame itself (NOT a copy); it is mutated in place by
            :meth:`produce_inverse_propensity_score`.
        p_hat, weight, probability_of_class_membership, class_indicator:
            names of the columns this class adds to the data.
    """

    def __init__(self, data, protected_class_field, protected_class_value, combo_scan, score_variable, event_var, conditional_variable, feature_list, scan_type, scan_feature_list):
        """Store the configuration and build the default propensity model.

        Args:
            data: pandas DataFrame holding every column named below.
            protected_class_field: column identifying class membership.
            protected_class_value: value of that column marking the
                protected class.
            combo_scan: ``"both"``, ``"positive"`` or ``"negative"``; see
                :meth:`filter_by_outcome`.
            score_variable: stored as ``self.score``; not read by any method
                of this class.
            event_var: column used as the training label (or, for
                ``"prediction_separation"``, as the weight split).
            conditional_variable: column appended to the model features and
                used by :meth:`filter_by_outcome`.
            feature_list: list of columns fed to ``pd.get_dummies`` to build
                the model features.
            scan_type: scan flavour; ``"prediction_separation"`` and
                ``"prediction_sufficiency"`` get special handling in
                :meth:`produce_p_hat`.
            scan_feature_list: columns returned by :meth:`run` as the scan
                features.
        """
        # Untouched snapshot; self.data below aliases the caller's frame.
        self.data_original = data.copy(deep=True)
        self.data = data
        self.protected_class_field = protected_class_field
        self.protected_class_value = protected_class_value
        self.combo_scan = combo_scan
        self.scan_feature_list = scan_feature_list

        self.score = score_variable

        # Default model used for the class-membership (propensity) fit.
        self.logistic = linear_model.LogisticRegression(penalty="l2", solver="lbfgs")
        self.event_var = event_var
        self.conditional_variable = conditional_variable

        # Names of the columns this class writes. TODO: make configurable.
        self.p_hat = "p_hat"
        self.weight = "w"
        self.probability_of_class_membership = "probability_of_class_membership"
        self.class_indicator = "protected_class"

        self.feature_list = feature_list
        self.scan_type = scan_type

    def produce_inverse_propensity_score(self):
        """Add class-indicator, membership-probability and weight columns.

        Mutates ``self.data`` in place:

        * adds ``self.class_indicator`` (1 where the protected-class field
          equals the protected value, else 0);
        * deletes the ``protected_class_field`` column;
        * adds ``self.probability_of_class_membership`` (predicted
          probability of the indicator being 1);
        * adds ``self.weight`` = ``p / (1 - p)``.

        NOTE(review): a predicted probability of exactly 1 yields an
        infinite weight.
        """
        # 1. create the indicator variable
        self.data[self.class_indicator] = np.where(
            self.data[self.protected_class_field] == self.protected_class_value, 1, 0
        )
        del self.data[self.protected_class_field]

        # 2. train the logistic regression on the one-hot encoded features
        X_train = pd.get_dummies(self.data[self.feature_list], drop_first=True)
        y_train = self.data[self.class_indicator]
        self.logistic.fit(X_train, y_train)

        # 3. convert membership probabilities into odds weights
        probs = self.logistic.predict_proba(X_train)[:, 1]
        self.data[self.probability_of_class_membership] = probs
        self.data[self.weight] = probs / (1 - probs)

    def create_compatible_training_test_datasets(self, df1, df2):
        """Give ``df2`` exactly the same set of columns as ``df1``.

        ``df2`` is modified in place: columns that are not in ``df1`` are
        deleted and columns of ``df1`` missing from ``df2`` are added filled
        with 0. Column *order* of ``df2`` is not changed to match ``df1``.

        Returns:
            ``(df1, df2)`` - the same two objects that were passed in.
        """
        reference_columns = list(df1)
        reference_set = set(reference_columns)

        # 1. remove all columns in df2 which aren't in df1
        for col in list(df2):
            if col not in reference_set:
                del df2[col]

        # 2. add all columns of df1 that df2 lacks, filled with 0
        present = set(df2)
        for col in reference_columns:
            if col not in present:
                df2[col] = 0

        return df1, df2

    def z_norm_given_params(self, arr, mean, std):
        """Return ``(arr - mean) / std``."""
        return (arr - mean) / std

    def z_norm(self, arr, epsilon=0):
        """Standardize ``arr`` and return the parameters that were used.

        Args:
            arr: array-like exposing ``.mean()`` and ``.std()``.
            epsilon: value added to the standard deviation before dividing.

        Returns:
            ``(normalized, mean, scale)``. ``scale`` is the divisor actually
            applied (``std + epsilon``), so passing ``mean`` and ``scale`` to
            :meth:`z_norm_given_params` reproduces the same transformation.
            When the standard deviation is 0 the input is returned unchanged
            together with ``mean=0`` and ``scale=1``.
        """
        std = arr.std()
        if std == 0:
            return arr, 0, 1
        mean = arr.mean()
        scale = std + epsilon
        return (arr - mean) / scale, mean, scale

    def normalize_data(self, x_train, x_test, cols):
        """Standardize ``cols`` in both frames using training statistics.

        Each listed column of ``x_train`` is z-normalized; the same mean and
        scale are then applied to that column of ``x_test``. Both frames are
        modified in place and returned.
        """
        for col in list(cols):
            x_train_col, u, std = self.z_norm(np.array(x_train[col]))
            x_train[col] = x_train_col
            x_test[col] = self.z_norm_given_params(np.array(x_test[col]), u, std)
        return x_train, x_test

    def normalize_data_simple(self, x_train, x_test):
        """Standardize every column of each frame with that frame's own stats.

        Columns whose standard deviation is 0 are set to 0. Both frames are
        modified in place and returned.

        NOTE(review): unlike :meth:`normalize_data`, the test frame is scaled
        with its own mean/std rather than the training ones - confirm this is
        intended before using it.
        """
        for frame in (x_train, x_test):
            for c in list(frame):
                std = frame[c].std()
                u = frame[c].mean()
                if std == 0:
                    frame[c] = 0
                else:
                    frame[c] = (frame[c] - u) / std
        return x_train, x_test

    def remove_correlated_variables(self, x_train, x_test, conditional_variable, cut_off=.5):
        """Drop columns strongly correlated with the conditional variable.

        A column of ``x_train`` is removed (from ``x_train`` and, when
        present, from ``x_test``) if the absolute Pearson correlation with
        ``x_train[conditional_variable]`` is at least ``cut_off``. The
        conditional variable itself is never removed. Both frames are
        modified in place and returned as a tuple.
        """
        reference = x_train[conditional_variable]
        for col in list(x_train):
            if col == conditional_variable:
                continue
            corr, _ = pearsonr(reference, x_train[col])
            if np.abs(corr) >= cut_off:
                del x_train[col]
                # The test frame may legitimately lack this column.
                if col in x_test:
                    del x_test[col]
        return (x_train, x_test)

    def produce_p_hat(self, log_reg_hyperparams):
        """Fit the weighted control model and score the protected rows.

        Args:
            log_reg_hyperparams: currently unused; accepted so the signature
                stays compatible with existing callers.

        Returns:
            ``(control, treatment, coefficients)`` where ``control`` holds the
            rows with class indicator 0, ``treatment`` the rows with class
            indicator 1 plus a ``self.p_hat`` column, and ``coefficients``
            maps each model column name to its fitted coefficient, with the
            model intercept stored under ``"intercept"``.
            ``(None, None, None)`` is returned when either group is empty or
            the training labels contain fewer than two distinct values.

        Side effect: ``self.logistic`` is replaced by the newly fitted model.
        """
        # 1. divide into two data sets, non-protected and protected.
        # Explicit copies so adding the p_hat column never writes through a
        # filtered view of self.data.
        df_r_0 = self.data[self.data[self.class_indicator] == 0].copy()
        df_r_1 = self.data[self.data[self.class_indicator] == 1].copy()

        if len(df_r_1.index) < 1 or len(df_r_0.index) < 1:
            return None, None, None

        is_separation = self.scan_type == "prediction_separation"

        # 2. build training features, labels and sample weights
        x_train = pd.get_dummies(df_r_0[self.feature_list], drop_first=True)
        if is_separation:
            # Every control row appears twice: once labelled 1 with weight
            # w * event_var and once labelled 0 with weight w * (1 - event_var).
            x_train = pd.concat([x_train, x_train], axis=0)
            w = list(df_r_0[self.weight] * df_r_0[self.event_var]) + list(df_r_0[self.weight] * (1 - df_r_0[self.event_var]))
            y_train = pd.Series([1] * len(df_r_0) + [0] * len(df_r_0))
        else:
            w = df_r_0[self.weight]
            y_train = df_r_0[self.event_var]

        # the model needs at least two classes to be fitted
        if len(y_train.unique()) < 2:
            return None, None, None

        x_test = pd.get_dummies(df_r_1[self.feature_list], drop_first=True)

        # collect the dummy feature names prior to adding the conditional
        # variable: only these columns are normalized below
        dummy_feature_list = list(x_train)

        # 3. append the conditional variable (duplicated for the stacked
        # prediction_separation design so the row labels still line up)
        train_conditional = df_r_0[self.conditional_variable]
        if is_separation:
            train_conditional = pd.concat([train_conditional, train_conditional], axis=0)
        x_train = pd.concat([x_train, train_conditional], axis=1)
        x_test = pd.concat([x_test, df_r_1[self.conditional_variable]], axis=1)

        # using log odds rather than probabilities for prediction_sufficiency;
        # the 1e-6 keeps the log finite when the odds are exactly zero.
        # NOTE(review): a conditional value of exactly 1 still divides by zero.
        if self.scan_type == "prediction_sufficiency":
            x_train[self.conditional_variable] = np.log((x_train[self.conditional_variable] / (1 - x_train[self.conditional_variable])) + 1e-6)
            x_test[self.conditional_variable] = np.log((x_test[self.conditional_variable] / (1 - x_test[self.conditional_variable])) + 1e-6)

        # 4. align columns and normalize the dummy features
        x_train, x_test = self.create_compatible_training_test_datasets(x_train, x_test)
        w = np.array(w)
        x_train, x_test = self.normalize_data(x_train, x_test, dummy_feature_list)

        # 5. fit the weighted model on control rows, score the treatment rows
        self.logistic = linear_model.LogisticRegression()
        self.logistic.fit(x_train, y_train, sample_weight=w)
        train_columns = list(x_train)
        # Index x_test by the training columns so the column order matches.
        df_r_1[self.p_hat] = self.logistic.predict_proba(x_test[train_columns])[:, 1]

        # 6. expose the fitted parameters by column name
        coefs = list(self.logistic.coef_)[0]
        d = dict(zip(train_columns, coefs))
        d["intercept"] = self.logistic.intercept_

        return df_r_0, df_r_1, d

    def filter_by_outcome(self):
        """Restrict ``self.data`` according to ``self.combo_scan``.

        ``"both"`` keeps everything, ``"positive"`` keeps rows whose
        conditional variable equals 1 and ``"negative"`` those where it
        equals 0. Any other value prints a message and leaves the data
        untouched.
        """
        if self.combo_scan == "both":
            return

        if self.combo_scan == "positive":
            self.data = self.data[self.data[self.conditional_variable] == 1]
        elif self.combo_scan == "negative":
            self.data = self.data[self.data[self.conditional_variable] == 0]
        else:
            print("combo scan param is not recognized")

    def run(self, log_reg_hyperparams=None):
        """Execute the full preprocessing pipeline.

        Args:
            log_reg_hyperparams: forwarded to :meth:`produce_p_hat`.

        Returns:
            A 9-tuple ``(data_treatment, data_treatment_predicted_probs,
            data_treatment_equiv_probs, data_treatment_conditional_var,
            data_control, data_control_predicted_probs,
            data_control_conditional_var, data_control_weights,
            coefficient_mapping)``. All nine entries are ``None`` when
            :meth:`produce_p_hat` could not fit a model.

        Side effects: sets ``self.control`` and ``self.treatment`` and
        mutates ``self.data``.
        """
        # 1. class-membership probabilities and weights (same for all scans)
        self.produce_inverse_propensity_score()

        # 2. filter by outcome
        self.filter_by_outcome()

        # 3. fit the control model and score the treatment rows
        self.control, self.treatment, coefficient_mapping = self.produce_p_hat(log_reg_hyperparams)

        if self.control is None:
            return (None,) * 9

        data_treatment = self.treatment[self.scan_feature_list]
        data_control = self.control[self.scan_feature_list]

        data_treatment_predicted_probs = self.treatment[self.event_var]
        data_treatment_equiv_probs = self.treatment[self.p_hat]
        data_control_predicted_probs = self.control[self.event_var]

        data_treatment_conditional_var = self.treatment[self.conditional_variable]
        data_control_conditional_var = self.control[self.conditional_variable]
        data_control_weights = self.control[self.weight]

        return (
            data_treatment,
            data_treatment_predicted_probs,
            data_treatment_equiv_probs,
            data_treatment_conditional_var,
            data_control,
            data_control_predicted_probs,
            data_control_conditional_var,
            data_control_weights,
            coefficient_mapping,
        )