import pandas as pd
import numpy as np

class BernoulliScan:
    """Bernoulli scoring helpers built around a multiplicative parameter q.

    The score evaluated throughout this class is

        thesum * ln(q) - sum_i ln(1 - p_i + q * p_i) - penalty

    where ``thesum`` is the sum of the observed events and ``p_i`` are the
    entries of ``theprobs``. ``q_mle`` is the q maximising that expression
    (found by bisection on the sign of its slope), and ``[q_min, q_max]``
    are the points around ``q_mle`` where the score crosses zero.
    """

    # Bracket and stopping tolerance shared by every bisection over q.
    _Q_LOWER = 0.000001
    _Q_UPPER = 1000000.0
    _Q_TOLERANCE = 0.000001

    def __init__(self):
        self.cut_off = 1

    def score_current_subset(self, coordinates, p_hat, event, penalty, current_subset, direction='positive', q_given=None):
        """Score the rows of ``coordinates`` selected by ``current_subset``.

        Parameters
        ----------
        coordinates : DataFrame of attribute columns.
        p_hat : per-row values used as ``theprobs``; indexed with the same
            boolean row mask as ``event``.
        event : per-row observed outcomes, aligned with ``coordinates``.
        penalty : penalty charged per attribute value listed in
            ``current_subset``.
        current_subset : dict mapping column name -> collection of accepted
            values. An empty dict selects every row.
        direction : with ``'positive'`` an MLE below 1 is clamped to 1; with
            any other value an MLE above 1 is clamped to 1. Ignored when
            ``q_given`` is supplied.
        q_given : if not None, score at this q instead of searching for
            the MLE.

        Returns
        -------
        (penalized_score, q) where q is the (possibly clamped) MLE, or
        ``q_given`` when that was supplied.
        """
        if current_subset:
            # Keep rows where every constrained column holds an accepted value.
            to_choose = coordinates[list(current_subset)].isin(current_subset).all(axis=1)
            chosen_events = event[to_choose]
            temp_df = pd.concat(
                [
                    coordinates.loc[to_choose],
                    chosen_events,
                    pd.Series(data=p_hat[to_choose], index=chosen_events.index, name='p_hat'),
                ],
                axis=1,
            )
        else:
            # No constraints: every row takes part.
            temp_df = pd.concat(
                [coordinates, event, pd.Series(data=p_hat, index=event.index, name='p_hat')],
                axis=1,
            )

        # After the concat the last column is p_hat and the one before it is
        # the (last) event column.
        thesum = temp_df.iloc[:, -2].sum()
        theprobs = temp_df.iloc[:, -1]

        # penalty * total number of attribute values named in the subset.
        totalpenalty = sum(len(values) for values in current_subset.values()) * penalty

        if q_given is not None:
            penalized_score = self.compute_score_given_q(thesum, theprobs, totalpenalty, q_given)
            return penalized_score, q_given

        current_q_mle = self.binary_search_on_slopeterm(thesum, theprobs)
        wrong_side = (
            (direction == 'positive' and current_q_mle < 1)
            or (direction != 'positive' and current_q_mle > 1)
        )
        if wrong_side:
            # The MLE lies on the wrong side of 1 for the requested
            # direction, so fall back to q = 1.
            current_q_mle = 1
        penalized_score = self.compute_score_given_q(thesum, theprobs, totalpenalty, current_q_mle)
        return penalized_score, current_q_mle

    def _bisect(self, lower, upper, raise_lower):
        """Halve ``[lower, upper]`` until it is narrower than the tolerance.

        ``raise_lower(midpoint)`` decides which half survives: a truthy
        result moves the lower bound up to the midpoint, otherwise the upper
        bound moves down. Returns the midpoint of the final bracket.
        """
        while np.abs(upper - lower) > self._Q_TOLERANCE:
            midpoint = (lower + upper) / 2
            half_width = (upper - lower) / 2
            if raise_lower(midpoint):
                lower = lower + half_width
            else:
                upper = upper - half_width
        return (lower + upper) / 2

    def binary_search_on_score_for_q_min(self, thesum, theprobs, penalty, q_mle):
        """Bisect on ``[1e-6, q_mle]`` for the point where the score crosses zero.

        A positive score at the midpoint moves the upper bound down, so the
        search converges on the lower zero crossing when the score is
        positive at ``q_mle``.
        """
        return self._bisect(
            self._Q_LOWER,
            q_mle,
            lambda q: not (self.compute_score_given_q(thesum, theprobs, penalty, q) > 0),
        )

    def binary_search_on_score_for_q_max(self, thesum, theprobs, penalty, q_mle):
        """Bisect on ``[q_mle, 1e6]`` for the point where the score crosses zero.

        A positive score at the midpoint moves the lower bound up, so the
        search converges on the upper zero crossing when the score is
        positive at ``q_mle``.
        """
        return self._bisect(
            q_mle,
            self._Q_UPPER,
            lambda q: self.compute_score_given_q(thesum, theprobs, penalty, q) > 0,
        )

    def compute_param(self, theprobs, theevents, penalty):
        """Return ``(positive, q_mle, q_min, q_max)`` for the given data.

        ``positive`` is 1 when the score at ``q_mle`` is strictly positive,
        in which case ``q_min`` and ``q_max`` are the bisection results on
        either side of ``q_mle``. Otherwise ``positive``, ``q_min`` and
        ``q_max`` are all 0.
        """
        thesum = theevents.sum()
        q_mle = self.binary_search_on_slopeterm(thesum, theprobs)
        if self.compute_score_given_q(thesum, theprobs, penalty, q_mle) > 0:
            positive = 1
            q_min = self.binary_search_on_score_for_q_min(thesum, theprobs, penalty, q_mle)
            q_max = self.binary_search_on_score_for_q_max(thesum, theprobs, penalty, q_mle)
        else:
            positive = 0
            q_min = 0
            q_max = 0
        return positive, q_mle, q_min, q_max

    def compute_slopeterm_given_q(self, thesum, theprobs, q):
        """Return ``thesum - sum_i q*p_i / (1 - p_i + q*p_i)``.

        This is q times the slope of the score with respect to q, which has
        the same sign as the slope for q > 0. Computed with array
        arithmetic, so ``theprobs`` may be a pandas Series or a NumPy array.
        """
        scaled = q * theprobs
        return thesum - (scaled / (1 - theprobs + scaled)).sum()

    def binary_search_on_slopeterm(self, thesum, theprobs):
        """Bisect on ``[1e-6, 1e6]`` for the q where the slope term changes sign.

        A positive slope term at the midpoint moves the lower bound up. If
        the slope term never changes sign inside the bracket the result
        ends up at the corresponding end of the bracket.
        """
        return self._bisect(
            self._Q_LOWER,
            self._Q_UPPER,
            lambda q: self.compute_slopeterm_given_q(thesum, theprobs, q) > 0,
        )

    def compute_score_given_q(self, thesum, theprobs, penalty, q):
        """Return ``thesum*ln(q) - sum_i ln(1 - p_i + q*p_i) - penalty``.

        Prints a warning (and still evaluates the expression) when
        ``q <= 0``, where ``ln(q)`` is not finite.
        """
        if (q <= 0):
            print("Warning: calling compute_score_given_q with thesum=",thesum,"theprobs of length",len(theprobs),"penalty=",penalty,"q=",q)
        return thesum * np.log(q) - np.log(1 - theprobs + q * theprobs).sum() - penalty

    def compute_param_mle(self, p_hat, event, penalty):
        """Return the MLE of q for ``event`` given ``p_hat``.

        ``penalty`` is accepted but not used by this method.
        """
        return self.binary_search_on_slopeterm(event.sum(), p_hat)

    def compute_score_given_param(self, the_event_series, the_p_hat_series, current_param_mle, penalty):
        """Score the given events and probabilities at ``current_param_mle``."""
        return self.compute_score_given_q(the_event_series.sum(), the_p_hat_series, penalty, current_param_mle)
    
    