import pandas as pd
import itertools
import random
from scipy.stats import bernoulli
import numpy as np
import copy
import traceback


import os
import sys
# Make the project's helper modules importable when this script is run from
# its own directory. The two backslash entries are written as raw strings:
# their values are unchanged, but Python no longer reports an "invalid escape
# sequence" warning (a SyntaxWarning since Python 3.12).
# NOTE(review): "../*" is not glob-expanded by the import system, and the
# "\ " entries contain a literal backslash, so those three entries most likely
# match no directory; "../Benchmark Tests" is the spelling that works.
sys.path.append("..")
sys.path.append("../..")
sys.path.append("../*")
sys.path.append(r"../Benchmark\ Tests/")
sys.path.append(r"../Benchmark\ Tests")
sys.path.append("")
sys.path.append('../')
sys.path.append("../Benchmark Tests")




from util_functions_cbs_benchmarking import *

from conditional_bias_scan import ConditionalBiasScan
from cbs_preprocessor import CBSPreProcessor
from cbs_logger import CBSLogger
from yaml_funcs import YamlFunctions
from dataset_specific_funcs import DatasetSpecificFuncs
from sklearn import linear_model


import matplotlib.pyplot as plt
import time

from  scipy.stats import pearsonr

import multiprocessing as mp

# ----------------------------------------------------------------------
# Input data and configuration
# ----------------------------------------------------------------------

# Timestamp suffix (it already begins with "_") used in the names of this
# run's output folders.
timestr = time.strftime("_%Y%m%d_%H%M%S")

# Load the COMPAS data set and keep two independent deep copies of it.
df = pd.read_csv("../../toy_datasets/COMPAS.csv")
df_copy = df.copy(deep=True)
deep_df_copy = df_copy.copy(deep=True)

# Feature columns of the data set.
# NOTE(review): not referenced anywhere else in the visible script.
cols_copy = ["Under 25", "Prior Offenses", "Race", "ChargeDegree", "Sex"]

# Path of the YAML configuration handed to YamlFunctions further down.
yaml_configs_path = "../fsscan_yamls/fsscan_configs-COMPAS-permutation_testing.yaml"

# Bookkeeping for parallel processing.
# NOTE(review): the scheduler at the bottom of the script re-initialises all
# four of these (with a limit of 20 workers) before it uses them.
active_workers, completed_workers = 0, 0
active_processes_list = []
active_worker_constant = 10


def run_wrapper( run_info):
    """Run one conditional bias scan on permuted data and save it to CSV.

    Meant to be the target of a ``multiprocessing.Process``. Apart from its
    argument it reads the module-level globals ``fsscan_configs``,
    ``cbs_logger`` and ``timestr``, which the top-level script creates before
    starting the workers.
    NOTE(review): this relies on the children inheriting the parent's globals
    (the ``fork`` start method); under ``spawn`` the whole script would be
    re-executed in each child -- confirm the target platform.

    Args:
        run_info: dict with ``"run_number"``, ``"protected_class"`` (a
            ``(column, value)`` pair) and ``"scan_params"`` (a dict with
            ``scan_info``, ``dataset_yaml``, ``data``, ``p_bin_var`` and
            ``tilde_probability_var``). It is mutated: ``"scan_params"`` is
            removed and the scan settings and results are added.

    Side effects:
        Prints progress and writes a one-row CSV into
        ``permutation_testing_results/_<timestr>/runs_results/``.
    """
    scan_params = run_info["scan_params"]
    scan = scan_params["scan_info"]
    dataset_yaml = scan_params["dataset_yaml"]
    data = scan_params["data"]
    p_bin_var = scan_params["p_bin_var"]
    tilde_probability_var = scan_params["tilde_probability_var"]
    fsscan_params = fsscan_configs["fsscan_params"]

    key, key_value = run_info["protected_class"]
    print("printing scan type:::")
    print(scan["scan_type"])
    print(str(run_info["run_number"]))

    cbs = ConditionalBiasScan(
        scan["protected_class"],
        scan["protected_value"],
        scan["combo"],
        scan["event"],
        scan["conditional_variable"],
        fsscan_params,
        scan["direction"],
        scan["feature_list"],
        scan["scan_type"],
        scan["scan_feature_list"],
        scan["threshold_probability"],
        scan["threshold_cutoff"],
    )
    results = cbs.run(dataset_yaml, data, p_bin_var, tilde_probability_var)

    stats_dict = cbs_logger.write_results(
        results["best_subset"],
        results["best_score"],
        results["best_param"],
        results["treatment"],
        results["treatment_events"],
        results["treatment_p_hat"],
        results["controls"],
        results["control_events"],
        results["control_conditional_var"],
        results["treatment_conditional_var"],
        results["dataset_yaml"],
        scan["protected_class"],
        scan["protected_value"],
        scan["combo"],
        scan["event"],
        scan["conditional_variable"],
        fsscan_params,
        scan["direction"],
        scan["feature_list"],
        scan["scan_type"],
        scan["scan_feature_list"],
        "",
        add_scores=True,
        include_conditional_var_base_rates=True,
    )

    # The (large) data copy must not end up in the one-row results CSV.
    del run_info["scan_params"]

    # "\\hat" is the literal text "\hat"; a bare "\h" is an invalid escape.
    print("coefficients used for variable of logistic regression used to produce \\hat p: ")
    print(results["p_hat_coefficient_mapping"])

    print("best subset found : " + str(results["best_subset"]))
    print("best score : " + str(results["best_score"]))
    print("param for best scoring subset : " + str(results["best_param"]))
    print(results["best_subset"])

    # Record the scan settings next to its results; insertion order defines
    # the CSV column order, so keep this sequence.
    run_info.update({
        "combo": scan["combo"],
        "event": scan["event"],
        "conditional_variable": scan["conditional_variable"],
        "fsscan_params": fsscan_params,
        "direction": scan["direction"],
        "feature_list": scan["feature_list"],
        "scan_type": scan["scan_type"],
        "scan_feature_list": scan["scan_feature_list"],
        "threshold_probability": scan["threshold_probability"],
        "threshold_cutoff": scan["threshold_cutoff"],
        "best_subset": results["best_subset"],
        "best_score": results["best_score"],
        "best_param": results["best_param"],
        "cbs_param": results["best_param"],
        "cbs_score": results["best_score"],
        "p_hat_coefficient_mapping": results["p_hat_coefficient_mapping"],
    })

    # Logger statistics win on any key collision.
    run_info = {**run_info, **stats_dict}

    runs_results_dir = "permutation_testing_results/_" + timestr + "/runs_results"
    file_name = (runs_results_dir + "/_run_num_" + str(run_info["run_number"])
                 + "_" + scan["scan_type"] + "_" + key + "_" + str(key_value)
                 + ".csv")
    pd.DataFrame([run_info]).to_csv(file_name)

    
# Create this run's output folders: per-run result CSVs go to runs_results/,
# the permuted input data sets to original_data_sets/.
folder_path = "permutation_testing_results/" + "_" + timestr
folder_path_bias = folder_path + "/runs_results"
folder_path_org = folder_path + "/original_data_sets"
# os.makedirs (unlike os.mkdir) also creates permutation_testing_results/
# itself on a fresh checkout. exist_ok stays False, so an already existing
# timestamped folder still raises instead of mixing results of two runs.
os.makedirs(folder_path_bias)
os.makedirs(folder_path_org)



# ----------------------------------------------------------------------
# Protected attribute whose column is permuted and scanned.
# Change this to test a different attribute.
key = "Under 25"
key_values_list = list(df_copy.loc[:, key].unique())
print(key_values_list)

############################################
# Build one job description (run_info) per (permutation run, protected value,
# scan). Every job carries its own deep copy of the permuted data, so the
# worker processes started below are independent of each other.
#
# fsscan_configs and cbs_logger are deliberately module-level names:
# run_wrapper reads them from the (inherited) globals of the worker process.
#
# Any exception throws away the partial job list and rebuilds it from scratch
# (drawing fresh permutations). The number of attempts is bounded so that a
# deterministic failure -- e.g. a missing YAML or CSV file -- is re-raised
# instead of being retried forever.
max_build_attempts = 25
build_attempt = 0
unsuccessful = True
while unsuccessful:
    build_attempt += 1
    try:
        run_infos = []
        for run_number in range(0, 200):
            print("running run number " + str(run_number))

            for key_v in key_values_list:
                print(key_v)
                run_info = {}
                run_info["run_number"] = run_number
                run_info["protected_class"] = (key, key_v)

                yaml_funcs = YamlFunctions(yaml_configs_path)
                fsscan_configs = yaml_funcs.run()

                cbs_logger = CBSLogger(fsscan_configs["results_folder"])

                # performing initial data preprocessing, in this case there is not any
                data_specs_func = DatasetSpecificFuncs(fsscan_configs["data_set_specific_yaml"], "tilde_p", "p_bin_var")
                data, dataset_yaml, tilde_probability_var, p_bin_var = data_specs_func.run()

                print(tilde_probability_var)

                ###########
                #### PERMUTATION HAPPENS HERE !!!!!!!!
                # Shuffle the protected column relative to all other columns.
                # NOTE(review): np.random is not seeded, so runs are not
                # reproducible.
                data[key] = np.random.permutation(data[key])
                ###########

                data.to_csv(folder_path_org + "/_run_num_" + str(run_number) + "_" + key + "_" + str(key_v) + ".csv")

                # producing all scans in config file
                scans = yaml_funcs.produce_scans(data, key, key_v)
                for scan in scans:
                    scan_params = {}
                    scan_params["scan_info"] = scan
                    scan_params["dataset_yaml"] = dataset_yaml
                    scan_params["data"] = data.copy(deep=True)
                    scan_params["p_bin_var"] = p_bin_var
                    scan_params["tilde_probability_var"] = tilde_probability_var
                    run_info["scan_params"] = scan_params
                    # run_info is reused across scans, so store a snapshot.
                    run_info_deep_copy = copy.deepcopy(run_info)
                    run_infos.append(run_info_deep_copy)

            # NOTE(review): df_copy is not modified inside this loop; this
            # reset appears to be vestigial.
            df_copy = deep_df_copy.copy(deep=True)

        print("Number of runs: " + str(len(run_infos)))

        # shuffle so jobs of the same run number are not executed together
        random.shuffle(run_infos)

        unsuccessful = False
    except Exception:
        if build_attempt >= max_build_attempts:
            print("giving up after " + str(build_attempt) + " attempts to build the runs")
            raise
        print("running again : regenerating : error below")
        print(traceback.format_exc())
        


# ----------------------------------------------------------------------
# Execute the jobs in run_infos with at most active_worker_constant
# concurrent worker processes, each running run_wrapper on one job.
from multiprocessing.connection import wait as wait_for_process_exit

active_workers = 0      # worker processes started so far
completed_workers = 0   # worker processes that have exited
failed_runs = 0         # exited workers whose exit code was non-zero

active_processes_list = []
active_worker_constant = 20
poll_timeout_seconds = 30

#run_infos = run_infos[:69]
while (len(run_infos) > 0) or (active_workers > completed_workers):
    print(len(run_infos))

    # Top up to active_worker_constant running processes while jobs remain.
    if len(active_processes_list) < active_worker_constant and len(run_infos) > 0:
        print(len(run_infos))

        needed_workers = min(active_worker_constant - len(active_processes_list),
                             len(run_infos))

        print("Will create " + str(needed_workers) + " processes")

        new_workers = [mp.Process(target=run_wrapper, args=(run_infos.pop(),))
                       for _ in range(needed_workers)]

        print(len(run_infos))

        for worker in new_workers:
            active_workers = active_workers + 1
            worker.start()

            print("starting worker " + str(active_workers))

        active_processes_list = active_processes_list + new_workers

    # Block until at least one worker exits (or the timeout passes) rather
    # than sleeping a fixed 30 s, so free slots are refilled -- and the loop
    # ends -- as soon as possible. The list is never empty here: workers
    # were just started, or unfinished ones are still in it.
    print("waiting up to " + str(poll_timeout_seconds) + " seconds for a worker to finish")
    wait_for_process_exit([p.sentinel for p in active_processes_list],
                          timeout=poll_timeout_seconds)

    if len(run_infos) == 0:
        print("all jobs are assigned.. waiting for all workers to complete")

    # check if workers are alive or not
    replacement_list = []
    for process in active_processes_list:
        if process.is_alive():
            replacement_list.append(process)
            continue

        # The process has already exited; join() reaps it (terminate() on a
        # finished process did nothing useful).
        process.join()
        print("there is a complete worker")
        completed_workers = completed_workers + 1
        print("total complete workers : " + str(completed_workers))
        if process.exitcode != 0:
            failed_runs = failed_runs + 1
            print("there was an unsuccesfful run!")
            print(process.exitcode)

        print('ended completed process')

    active_processes_list = replacement_list

print("finished: " + str(completed_workers) + " workers completed, "
      + str(failed_runs) + " with a non-zero exit code")