import numpy as np
import time
import pandas as pd

from scan_distribution_specific_score_functions.bernoulli_scan import BernoulliScan
from scan_distribution_specific_score_functions.gaussian_scan import GaussianScan

class FSScan:
    
    def __init__(self, printing_on):
        """Store the verbosity flag.

        printing_on: when truthy, md_scan and evaluate_scan print progress
            and result details; when falsy those messages are suppressed.
        """
        self.printing_on = printing_on
        
    def get_aggregates(self,coordinates,p_hat, event, values_to_choose,column_name,penalty,direction):
        #print("Calling get_aggregates with column_name=",column_name)
        
        if values_to_choose:
            to_choose = coordinates[values_to_choose.keys()].isin(values_to_choose).all(axis=1)
            temp_df=pd.concat([coordinates.loc[to_choose], p_hat[to_choose],event[to_choose]],axis=1)
        else:
            temp_df= pd.concat([coordinates, p_hat,event],axis=1)
        aggregates = {}
        thresholds = set()

        #_, i = np.unique(temp_df.columns, return_index=True)
        #temp_df = temp_df.loc[:, i]

        for name, group in temp_df.groupby(column_name):
            #print("creating aggregrate for : "+ str(name))
            p_hat_select = group.iloc[:,-2]
            event_select = group.iloc[:,-1]
            positive, param_mle, param_min, param_max = self.scanner.compute_param(p_hat_select,event_select,penalty)
            #print("name=",name,"positive=",positive, "mu_mle=",mu_mle,"mu_min=",mu_min,"mu_max=",mu_max)
            if positive:
                if direction == 'positive':
                    if param_max < self.scanner.cut_off:
                        positive = 0
                    elif param_min < self.scanner.cut_off:
                        param_min = self.scanner.cut_off
                else: 
                    if param_min > self.scanner.cut_off:
                        positive = 0
                    elif param_max > self.scanner.cut_off:
                        param_max = self.scanner.cut_off
                if positive:
                    aggregates[name]={'positive':positive, 'param_mle':param_mle,'param_min':param_min,'param_max':param_max,'p_hat_select':p_hat_select, "event_select":event_select}
                    thresholds.update([param_min,param_max])

        #allsum = temp_df.iloc[:,-2].sum()
        p_hat = temp_df.iloc[:,-2]
        event = temp_df.iloc[:,-1]
        return [aggregates,sorted(thresholds),p_hat, event]
    
    #def search_for_param(self, p_hat, event):
    #    if self.sufficiency_prediction:
    #        return self.(self, events, p_hat, mu, penalty)
            

    def choose_aggregates(self,aggregates,thresholds,penalty,p_hat, event,direction):
        best_score = 0
        best_param = 0
        best_names = []
        for i in range(len(thresholds)-1):
            thethreshold = (thresholds[i]+thresholds[i+1])/2
            names = []
            the_p_hats = []
            the_events = []
            #print("current threshold to check for in choose aggregate: "+ str(thethreshold))
            for key, value in aggregates.items():
                # if score is positive
                if (value['positive']) & (value['param_min'] < thethreshold) & (value['param_max'] > thethreshold):
                    names.append(key)
                    the_p_hats = the_p_hats + value['p_hat_select'].tolist()
                    the_events = the_events + value["event_select"].tolist()
            the_p_hats_series = pd.Series(the_p_hats)
            the_events_series = pd.Series(the_events)
            current_param_mle = self.scanner.compute_param_mle(the_p_hats_series,the_events_series,penalty)
            if ((direction == 'positive') & (current_param_mle < self.scanner.cut_off)) | ((direction != 'positive') & (current_param_mle > self.scanner.cut_off)):
                current_param_mle = self.scanner.cut_off
            current_score = self.scanner.compute_score_given_param(the_events_series, the_p_hats_series,current_param_mle, penalty*len(names))
            #print "In choose_aggregates, current_score = ",current_score+penalty*len(names),"-",penalty*len(names),"=",current_score
            if current_score > best_score:
                best_score = current_score
                best_param = current_param_mle
                best_names = names
            #print 'current',names,current_score,current_q_mle,'with penalty of',penalty*len(names)
        # also have to consider case of including all attributes values including those that never make positive contributions to the score
        all_p_hat_series = pd.Series(p_hat)
        all_event_series = pd.Series(event)
        current_param_mle = self.scanner.compute_param_mle(all_p_hat_series, all_event_series, penalty)
        if ((direction == 'positive') & (current_param_mle < self.scanner.cut_off)) | ((direction != 'positive') & (current_param_mle > self.scanner.cut_off)):
            current_param_mle = self.scanner.cut_off
        current_score = self.scanner.compute_score_given_param(all_event_series,all_p_hat_series,current_param_mle,0)
        #print "In choose_aggregates, current_score = ",current_score,"-[no penalty]=",current_score
        if current_score > best_score:
            best_score = current_score
            best_param = current_param_mle
            best_names = []
        return [best_names,best_score,best_param]

    # Only called on the first iteration of the algorithm (md_scan uses it when i == 0).

    # coordinates: features for women

    def mk_subset_all_values(self,coordinates):
        #print(coordinates)
        subset_all_values = {}
        for theatt in coordinates:
            subset_all_values[theatt]=coordinates[theatt].unique().tolist()
        #print("mk_subset_all_values return value:")
        #print(subset_all_values)
        return subset_all_values
    




    def mk_subset_random_values(self,coordinates,prob,minelements=0):
        subset_random_values = {}
        shuffled_column_names = np.random.permutation(coordinates.columns.values)
        for theatt in shuffled_column_names:
            temp = coordinates[theatt].unique()
            mask = np.random.rand(len(temp)) < prob
            if mask.sum() < len(temp):
                subset_random_values[theatt] = temp[mask].tolist()
                remaining_records = len(coordinates.loc[coordinates[subset_random_values.keys()].isin(subset_random_values).all(axis=1)])
                if remaining_records < minelements:
                    del subset_random_values[theatt]
        return subset_random_values
    
    

    


    def md_scan(self,coordinates,treatment_probs,treatment_equiv_probs,penalty,num_iters,direction,minelements=0):
        best_subset = {}
        best_score = -1E10
        best_mu = -10000000
        for i in range(num_iters): # [0,49]
            if self.printing_on:
                print("ITER: "+ str(i))
            flags = np.empty(len(coordinates.columns))

            flags.fill(0)
            # starting subset
            current_subset = self.mk_subset_all_values(coordinates) if i == 0 else self.mk_subset_random_values(coordinates,np.random.rand(),minelements)
            if self.printing_on:
                print("beginning subset of scan is: " + str(current_subset))
            current_score, current_param = self.scanner.score_current_subset(coordinates,treatment_equiv_probs,treatment_probs ,penalty,current_subset,direction)
            if self.printing_on:
                print("Starting subset with score of",current_score,":")
                print(current_subset)
                print(current_param)
            while flags.sum() < len(coordinates.columns):

                # choose random
                attribute_number_to_scan = np.random.choice(len(coordinates.columns))
                while flags[attribute_number_to_scan]:
                    attribute_number_to_scan = np.random.choice(len(coordinates.columns))
                attribute_to_scan = coordinates.columns.values[attribute_number_to_scan]

                #print('SCANNING:' + str(attribute_to_scan) + " and this is # " + str(attribute_number_to_scan))
                if attribute_to_scan in current_subset:
                    del current_subset[attribute_to_scan]# HERE!!! Now we must replace current_subset with temp_subset below
                #print("current subset after attribute is choosen: " + str(current_subset))

                aggregates,thresholds,p_hat, event = self.get_aggregates(coordinates,treatment_equiv_probs,treatment_probs,current_subset,attribute_to_scan,penalty,direction)

                temp_names,temp_score, temp_param=self.choose_aggregates(aggregates,thresholds,penalty,p_hat, event,direction)
                #print("choosen aggregate: ")
                #print("attribute names: " + str(temp_names))
                #print("attribute scores: " + str(temp_score))
                temp_subset = current_subset.copy()
                if temp_names: # if temp_names is not empty (or null)
                    temp_subset[attribute_to_scan]=temp_names
                temp_score, temp_param =  self.scanner.score_current_subset(coordinates,treatment_equiv_probs,treatment_probs,penalty,temp_subset,direction)
                
                if temp_score > current_score+1E-6:
                    flags.fill(0)
                elif temp_score < current_score-1E-6:
                    print("WARNING SCORE HAS DECREASED from",current_score,"to",temp_score)

                flags[attribute_number_to_scan] = 1            
                current_subset = temp_subset
                current_score = temp_score
                current_param = temp_param
            if self.printing_on:
                print("Subset found on iteration",i+1,"of",num_iters,"with score",current_score)
                print(current_subset)
                print(best_score)
            if (current_score > best_score):
                best_subset = current_subset.copy()
                best_score = current_score
                best_param = current_param
                #print "Best score is now",best_score
            #else:
                #print "Current score of",current_score,"does not beat best score of",best_score
        
        results = {}
        results["best_subset"] = best_subset
        results["best_score"] = best_score
        results["best_param"] = best_param
        return results

    def compare_control_subset(self,treatment, 
                               treatment_outcomes,
                               controls, 
                               control_outcomes,
                               subset,
                               brief=False):

        if (subset):
            treatment_to_choose = treatment[subset.keys()].isin(subset).all(axis=1)
            treatment_df = treatment[treatment_to_choose]
            treatment_outcomes_df = treatment_outcomes[treatment_to_choose]
            controls_to_choose = controls[subset.keys()].isin(subset).all(axis=1)
            control_df = controls[controls_to_choose]
            control_outcomes_df = control_outcomes[controls_to_choose]
        else:
            treatment_df = treatment
            treatment_outcomes_df = treatment_outcomes
            control_df = controls
            control_outcomes_df = control_outcomes        

        print(treatment_outcomes_df.count(),'treatment individuals with',100*treatment_outcomes_df.mean(),'% test positives')   
        print(control_outcomes_df.count(),'control individuals with',100*control_outcomes_df.mean(),'% test positives')

        for theatt in treatment:
            for thevalue in treatment[theatt].unique():
                print("testing specifically " + str(thevalue) + " for "+ str(theatt))
                count1a = treatment_df[treatment_df[theatt] == thevalue].iloc[:,0].count()
                count1b = treatment_df.iloc[:,0].count()-count1a
                count2a = control_df[control_df[theatt] == thevalue].iloc[:,0].count()
                count2b = control_df.iloc[:,0].count()-count2a

                odds, p = scs.fisher_exact([[count1a,count2a],[count1b,count2b]])
                if p < .05:
                    if brief:
                        print(theatt,thevalue,"<" if (count1a*1.0/(count1a+count1b))<(count2a*1.0/(count2a+count2b)) else ">")
                    else:
                        print("Attribute =",theatt,", Value =",thevalue,", p =",p,":",count1a,"of",count1a+count1b,"(",count1a*100.0/(count1a+count1b),"%)","vs.",count2a,"of",count2a+count2b,"(",count2a*100.0/(count2a+count2b),"%)")
                #else:
                #    print(f"Not significant p-value: {p}", "Attribute =",theatt,", Value =",thevalue)
        return




    def evaluate_scan(self, treatments,
                      treatment_probs,
                      treatment_equiv_probs,
                      controls,
                      control_probs,
                      scan_type,
                      direction,
                      subset_scan_penalty=0.1,
                      subset_scan_num_iters=50,
                      minelements=100):
        """Select a scanner, run md_scan on the treatment data and time it.

        ``scan_type == "prediction_separation"`` selects ``GaussianScan``;
        any other value selects ``BernoulliScan``.  The chosen scanner is
        stored on ``self.scanner`` and ``scan_type`` on ``self.scan_type``.

        Returns the dict produced by ``md_scan`` extended with the keys
        ``"required_time"`` (seconds spent in ``md_scan``), ``"treatment"``,
        ``"treatment_events"``, ``"treatment_p_hat"``, ``"controls"`` and
        ``"control_events"``, which echo the corresponding inputs.
        ``controls`` and ``control_probs`` are not used by the scan itself.
        """
        self.scan_type = scan_type
        print("penalty is: " + str(subset_scan_penalty))

        wants_gaussian = self.scan_type == "prediction_separation"
        self.scanner = GaussianScan() if wants_gaussian else BernoulliScan()

        # find and summarize highest-scoring positive subset
        started_at = time.time()
        results = self.md_scan(
            treatments,
            treatment_probs,
            treatment_equiv_probs,
            subset_scan_penalty,
            subset_scan_num_iters,
            direction,
            minelements,
        )
        elapsed_seconds = time.time() - started_at

        if self.printing_on:
            print("Required time = ",elapsed_seconds,"seconds")
            for label, key in (("best score: ", "best_score"),
                               ("best subset: ", "best_subset"),
                               ("best param: ", "best_param")):
                print(label + str(results[key]))

        results.update({
            "required_time": elapsed_seconds,
            "treatment": treatments,
            "treatment_events": treatment_probs,
            "treatment_p_hat": treatment_equiv_probs,
            "controls": controls,
            "control_events": control_probs,
        })
        return results