
import yaml
import pandas as pd
import numpy as np

class DatasetSpecificFuncs:
    """Apply dataset-specific preprocessing driven by a YAML config file.

    The YAML file must provide ``observed_label`` and ``calculate_mle_prob``.
    ``data_file`` is required when no DataFrame is passed to :meth:`run`.
    ``score`` is required when ``calculate_mle_prob`` is truthy.
    """

    def __init__(self, dataset_yaml_file, tilde_probability_var, p_bin_var):
        """Store configuration; no I/O happens until :meth:`run`.

        Args:
            dataset_yaml_file: Path to the dataset's YAML config file.
            tilde_probability_var: Name of the column that will hold the
                estimated probability.
            p_bin_var: Name of the column intended for a binarized
                probability. Currently it is only stored and returned by
                :meth:`run`; nothing in this class writes that column.
        """
        self.dataset_yaml_file = dataset_yaml_file
        self.tilde_probability_var = tilde_probability_var
        self.p_bin_var = p_bin_var

    def run(self, data=None):
        """Load the config (and data, if needed) and apply its settings.

        Args:
            data: Optional DataFrame to process. If ``None``, the CSV at the
                config's ``data_file`` is read instead. A supplied DataFrame
                is used as-is (not copied), so it is modified in place.

        Returns:
            Tuple ``(data, dataset_yaml, tilde_probability_var, p_bin_var)``:
            the processed DataFrame, the parsed config dict, and the two
            column names this instance was configured with.
        """
        # 1. Read the dataset config.
        with open(self.dataset_yaml_file, "r", encoding="utf-8") as file:
            self.dataset_yaml = yaml.safe_load(file)

        # 2. Load the data unless the caller supplied a DataFrame.
        if data is None:
            self.data = pd.read_csv(self.dataset_yaml["data_file"])
        else:
            self.data = data

        # Dataset-specific field names.
        self.observed_label = self.dataset_yaml["observed_label"]

        # 3. Apply dataset-specific settings.
        if self.dataset_yaml["calculate_mle_prob"]:
            self.score = self.dataset_yaml["score"]
            self.calc_mle_probability(self.score, self.observed_label)
        # TODO: decide what to do when the dataset already has an estimated
        # probability (tilde_p) or none at all (ie: NYPD compas scan); one
        # idea was: self.tilde_probability_var = self.dataset_yaml["score"]
        #
        # TODO: optional thresholding into self.p_bin_var, previously sketched as:
        #   if self.dataset_yaml["threshold_mle_prob"]:
        #       self.data[self.p_bin_var] = np.where(
        #           self.data[self.tilde_probability_var] >= self.dataset_yaml["threshold"], 1, 0)

        return self.data, self.dataset_yaml, self.tilde_probability_var, self.p_bin_var

    def calc_mle_probability(self, compas_score, reoffended_var):
        """Replace a score column with the per-score empirical label rate.

        For each distinct value of ``compas_score``, the mean of
        ``reoffended_var`` over rows with that value is mapped back onto
        every row and stored in ``self.tilde_probability_var``. The raw score
        column is then dropped. ``self.data`` is modified in place.

        For a 0/1 label, this per-group mean is the maximum-likelihood
        estimate of that group's probability (presumably the "mle" in the
        name). Rows whose score is NaN get a NaN probability, because
        groupby excludes NaN keys by default.

        Args:
            compas_score: Name of the score column to group by.
            reoffended_var: Name of the label column to average.
        """
        # Aggregate only the label column. Calling .mean() on the whole
        # grouped frame raises on non-numeric columns in pandas >= 2.0
        # (numeric_only defaults to False), and it also averages columns
        # that are discarded anyway.
        score_to_probability = self.data.groupby(compas_score)[reoffended_var].mean()
        self.data[self.tilde_probability_var] = self.data[compas_score].map(score_to_probability)

        # Drop the raw score, unless the probability was just written into
        # that same column; deleting it then would discard the result.
        if compas_score != self.tilde_probability_var:
            del self.data[compas_score]