import numpy as np; import pylab as plt; import time

class parameter_estimation():
    """Simulate GO/NOGO stimulus trials and track online estimates of P(GO) ("PGO").

    ``make_trials`` generates the stimulus stream; ``get_bayes`` replays it step by
    step through one of the estimators dispatched by ``update_PGO``. Each estimator is
    evaluated simultaneously for every value of a free parameter (usually a memory
    length) from 1 to ``max_mem``, so most per-step arrays carry a leading
    ``max_mem`` axis ("Mem").
    """

    # bayes_type -> name of the method implementing it (see update_PGO)
    _ESTIMATORS = {
        'standard': 'standard_estimate',
        'soft': 'soft_estimate',
        'discount': 'discount_estimate',
        'hybrid': 'hybrid_estimate',
        'dynamic': 'dynamic_estimate',
    }

    def __init__(self):
        self.A = 0                                  # network action at the current step
        self.a = .1                                 # p(safe)
        self.a_0 = 1 - self.a                       # 1 - p(safe)
        self.bayes_resolution = 50                  # resolution of bayesian distributions
        self.max_mem = 100                          # free parameter goes from 1 to max_mem
        self.data_len = 1000                        # total number of trials
        self.trial_dur = 100                        # max duration of a single trial
        # useful vectors
        self.padding = np.ones(self.max_mem)                                # stimuli used before enough steps exist to fill the memory window
        self.mem_range = 1 + np.arange(self.max_mem)                        # 1 .. max_mem, one entry per free-parameter value
        self.bayes_range = np.linspace(0, 1, self.bayes_resolution)         # every candidate PGO of the full distribution
        # [Mem, Mem]: row i uses rate linspace(.8, 1)[i]; column k holds rate**k
        self.discount_forget = np.linspace(.8, 1, self.max_mem)[:, None]**(self.mem_range[None, :]-1)
        # allocate storage so make_trials() / visualize_trials() also work without get_bayes()
        self.reset()

    def reset(self):
        """(Re)allocate every storage array and counter used by one simulation run."""
        mem, dur, n_trials = self.max_mem, self.trial_dur, self.data_len
        # per-step quantities of the current trial, dims [Mem, max trial duration]
        (self.curr_Psafe, self.curr_Punsafe, self.curr_PGO,
         self.curr_prior_weight, self.curr_model_free) = [np.ones((mem, dur)) for _ in range(5)]
        # per-trial logs, each entry a [Mem, trial length] array, dims [Trials]
        self.Psafe_log, self.PGO_log, self.prior_weight_log, self.model_free_log = [np.empty(n_trials, dtype = object) for _ in range(4)]
        # per-trial step counts, action times and rewards, dims [Trials]
        self.steps, self.act_times, self.rews = [np.zeros(n_trials) for _ in range(3)]
        # per-free-parameter counters, dims [Mem, 1]
        self.curr_n, self.running_stim = [np.zeros((mem, 1)) for _ in range(2)]
        self.bayes_dist = np.ones((mem, self.bayes_resolution))             # [Mem, Bayes]
        self.weights = np.ones((mem, mem))                                  # [Mem, past m stimuli]
        self.PGO_flat, self.backbone_flat = [], []                          # grow dynamically during trial generation
        self.bayes_step = self.known = self.unsafe = self.loaded_n = 0
        self.A = 0                                                          # do not carry an action flag over from a previous run
        self.bayes_type = None
        self.PGO_estimates = None                                           # filled by get_bayes()

    # ------------------------------------------------------------------
    # Trial generation and visualization
    # ------------------------------------------------------------------

    def make_trials(self, p_switch = .01):
        """Generate ``data_len`` trials, resampling PGO with probability ``p_switch`` before each trial.

        Relies on the storage allocated by ``reset()`` (called by ``__init__`` and ``get_bayes``).
        """
        self.sample_PGO()                                                   # initial PGO
        # stimulus durations: exponential with mean 1/a (no int() truncation of the scale)
        self.durs = np.random.exponential(1/self.a, size = (self.data_len, 1))
        for self.trial in range(self.data_len):                             # every trial, including the last one
            if np.random.rand() < p_switch:                                 # new block
                self.sample_PGO()
            self.run_trial()
        self.backbone_flat = np.array(self.backbone_flat)                   # flat array of all stimuli
        self.trial_ends = self.steps.astype(int) - 1                        # index of the last step of each trial

    def sample_PGO(self):
        """Draw a new PGO uniformly from [0, 1], rounded to one decimal."""
        self.PGO = round(np.random.rand(), 1)

    def run_trial(self, policy = .05, ongoing = True):
        """Append stimuli for trial ``self.trial`` until an action (prob. ``policy`` per step) or the duration cap."""
        dur = self.durs[self.trial, 0]
        while ongoing:
            step = self.steps[self.trial]
            # stimulus is always GO once past the trial's duration (SAFE), otherwise Bernoulli(PGO)
            stim = 1 if step > dur else np.random.binomial(1, self.PGO)
            # buffer of 5 steps below trial_dur avoids indexing past the per-step storage
            ongoing = (step < self.trial_dur - 5) and (np.random.rand() > policy)
            self.backbone_flat.append(stim)
            self.PGO_flat.append(self.PGO)
            self.steps[self.trial] = step + 1
        self.act_times[self.trial] = self.steps[self.trial]                 # step COUNT of the trial; the action is on the last step
        self.rews[self.trial] = self.steps[self.trial] > dur                # rewarded iff the trial outlasted its duration (SAFE state)

    def visualize_trials(self, mem = 20, smoothing = 50):
        """Plot stimuli, true PGO, a causal moving average and, if available, the bayesian estimate for ``mem``."""
        stim = np.asarray(self.backbone_flat)
        plt.figure(figsize = (40, 5))
        nogos = np.where(stim == 0)[0]
        gos = np.where(stim == 1)[0]
        plt.scatter(nogos, np.zeros(len(nogos)), marker = '.', c = 'C0', alpha = .5)
        plt.scatter(gos, np.ones(len(gos)), marker = '.', c = 'C0', alpha = .5)
        plt.plot(stim, alpha = .3, lw = .25, c = 'C0')
        plt.plot(self.PGO_flat, linewidth = 10, c = 'C1', alpha = .5)
        plt.plot(np.convolve(stim/smoothing, np.ones(smoothing)), c = 'C2')
        labels = ['nogo', 'go', 'stim', 'PGO', f'avg stim (window size {smoothing})']
        if self.PGO_estimates is not None:                                  # estimates only exist after get_bayes()
            plt.plot(self.PGO_estimates[mem, :], c = 'C3', alpha = .5)
            labels.append('bayesian estimate')
        plt.legend(labels)
        plt.title(f"{self.bayes_type} Free Parameter {mem}")
        plt.show()

    # ------------------------------------------------------------------
    # Trigger function
    # ------------------------------------------------------------------

    def get_bayes(self, bayes_type = None):
        """Generate fresh trials, run the ``bayes_type`` estimator over them, plot, and print elapsed seconds.

        Raises:
            ValueError: if ``bayes_type`` is not one of the types listed in ``update_PGO``.
        """
        if bayes_type not in self._ESTIMATORS:
            raise ValueError(f"unknown bayes_type {bayes_type!r}; expected one of {sorted(self._ESTIMATORS)}")
        t = time.time()
        self.reset()
        self.bayes_type = bayes_type
        self.make_trials()
        for self.curr_trial in range(self.data_len):
            self.end = self.trial_ends[self.curr_trial] + 1                 # number of steps in this trial
            self.bayes_loop()
            self.bayes_log()
        # np.hstack requires a sequence; generators raise TypeError on recent NumPy
        self.PGO_estimates = np.hstack(list(self.PGO_log))
        self.visualize_trials()
        print(time.time()-t)

    # ------------------------------------------------------------------
    # bayes processing
    # ------------------------------------------------------------------

    def bayes_log(self):
        """Store the current trial's estimates and carry its final PGO estimate into the next trial."""
        self.model_free_log[self.curr_trial] = self.curr_model_free[:, :self.end].copy()
        self.Psafe_log[self.curr_trial] = self.curr_Psafe[:, :self.end].copy()
        # column s+1 holds the estimate produced at step s, so this trial's estimates are columns 1..end
        self.PGO_log[self.curr_trial] = self.curr_PGO[:, 1:self.end+1].copy()
        self.curr_PGO[:, 0] = self.curr_PGO[:, self.end].copy()
        report_every = max(1, self.data_len // 10)                          # avoid modulo by zero when data_len < 10
        if self.curr_trial % report_every == 0:
            print(str(int(100 * self.curr_trial/self.data_len)) + "%")

    def bayes_loop(self):
        """Run inference then estimation on each of the ``self.end`` steps of the current trial."""
        for self.curr_step in range(self.end):
            self.Inference_step()
            self.Estimation_step()
            self.bayes_step += 1                                            # global index into backbone_flat
            self.unsafe += 1                                                # steps since the state was last known unsafe
            # steps since the state was last known, capped by the largest free parameter
            self.known = np.clip(self.known + 1, a_min = None, a_max = self.max_mem - 1)

    # ------------------------------------------------------------------
    # bayes calculations
    # ------------------------------------------------------------------

    # Inference step
    def Inference_step(self):
        """Update the consecutive-GO count, the inference terms and P(safe) for the current step."""
        self.update_n()
        self.update_b_c()
        self.update_psafe()

    def update_n(self):
        """Update the number of consecutive GOs ``curr_n`` ([Mem, 1]) with the current stimulus."""
        self.stim_1 = self.backbone_flat[self.bayes_step] == 1              # current stimulus is GO
        self.curr_n = np.minimum(self.known, self.curr_n)                   # cap consecutive GOs by the last known state
        # self.A still holds the previous step's action flag here (check_for_action runs later)
        self.curr_n = (1-self.A)*self.stim_1*(self.curr_n + 1)              # reset on NOGO or action, else increment
        self.curr_n = np.minimum(self.mem_range[:, None], self.curr_n)      # cap by the free parameter (memory size)

    def update_b_c(self):
        """Compute the per-free-parameter inference terms b, b^n, their complements and c."""
        self.b = self.a_0*self.curr_PGO[:, self.curr_step, None].copy()     # b = (1 - a) * current PGO estimate, [Mem, 1]
        self.b_n = (self.b**self.curr_n)                                    # b^n
        self.b_0 = 1 - self.b                                               # 1 - b
        self.b_n_0 = 1 - self.b_n                                           # 1 - b^n
        self.b_sum_to_n_min_1 = self.b_n_0 / self.b_0                       # sum_{i=0}^{n-1} b^i = (1 - b^n) / (1 - b)
        n_min_k = np.maximum(0, self.curr_n - np.arange(self.max_mem))      # [Mem, Mem]: max(0, n - k), k = 0 .. max_mem-1
        self.b_k_min_n = self.b**(n_min_k) - self.b**self.curr_n            # b^max(0, n-k) - b^n
        self.c = self.a_0*self.b_0                                          # (1 - a) * (1 - b)

    def update_psafe(self):
        """Store P(unsafe) and P(safe) for the current step."""
        # P(unsafe) = b^n / (1 - c * (1 - b^n) / (1 - b))
        self.curr_Punsafe[:, self.curr_step] = (self.b_n/(1 - self.c * self.b_sum_to_n_min_1)).squeeze()
        self.curr_Psafe[:, self.curr_step] = (1 - self.curr_Punsafe[:, self.curr_step])

    # Estimation step
    def Estimation_step(self):
        """Build the memory window, check for an action and update the PGO estimate."""
        # window of the last max_mem stimuli; all-GO padding is used while bayes_step <= max_mem
        stim = self.backbone_flat[self.bayes_step - self.max_mem+1 : self.bayes_step+1] if self.bayes_step > self.max_mem else self.padding
        self.last_m_stim = np.flip(stim)                                    # order [newest ... oldest]
        self.check_for_action()
        self.update_PGO()
        if not self.stim_1 or self.A:                                       # NOGO or action: the state is known
            self.known = 0

    def check_for_action(self):
        """Set ``A`` (action at the current step) and ``R`` (current trial rewarded)."""
        # act_times holds the trial's step COUNT; the action happens on its last step, index count - 1
        self.A = self.curr_step == self.act_times[self.curr_trial] - 1
        self.R = self.rews[self.curr_trial] == 1

    # ------------------------------------------------------------------
    # bayes specifics
    # ------------------------------------------------------------------

    def update_PGO(self):
        """Update the PGO estimate for every free parameter with the estimator named by ``bayes_type``.

        bayes types:

            standard =  simple memory window with weights for each stim in memory, for sufficient statistic

            soft =      expert bayesian updating only after NOGO, for full posterior distribution

            discount =  standard + decaying vector multiplied with the weights, for gradual forgetting

            hybrid =    weighted likelihood of each stim in memory, for full posterior distribution

            dynamic =   soft forgetting with dynamic discounting based on the estimated uncertainty

        Raises:
            ValueError: for any other ``bayes_type`` (previously an UnboundLocalError).
        """
        if self.bayes_type not in self._ESTIMATORS:
            raise ValueError(f"unknown bayes_type {self.bayes_type!r}; expected one of {sorted(self._ESTIMATORS)}")
        estimate = getattr(self, self._ESTIMATORS[self.bayes_type])
        PGO_update, PGO_prior = estimate()

        # model-free estimate: average stimulus for each window length
        self.curr_model_free[:, self.curr_step] = np.cumsum(self.last_m_stim)/self.mem_range
        self.curr_PGO[:, self.curr_step + 1] = PGO_update + PGO_prior
        if np.any(np.isnan(self.curr_PGO[:, self.curr_step + 1])):         # check the estimate just written
            print("nan present")

    def standard_estimate(self):
        """Weighted average of the memory window plus the remaining weight on the current estimate."""
        self.update_weights()
        stim_weights, prior_weights = self.normalize_weights(self.weights)
        PGO_update = stim_weights @ self.last_m_stim                        # [Mem, Mem] @ [Mem] -> [Mem]
        PGO_prior = prior_weights * self.curr_PGO[:, self.curr_step]
        return PGO_update, PGO_prior

    def soft_estimate(self, eps = 1e-5):
        """Full-posterior update applied only when the GO run ended (n reset); returns (posterior mean, 0)."""
        # GO run continuing: flat likelihood; otherwise sigma^n * (1 - sigma) using the previous step's n
        self.update = 1 if self.curr_n[-1] > 0 else (self.bayes_range**self.loaded_n)*(1 - self.bayes_range)
        self.loaded_n = self.curr_n[-1]                                     # n for the next step (curr_n is 0 after NOGO)
        # normalise the posterior; eps slowly pulls it towards uniform
        self.bayes_dist = self.norm_by_sum(self.bayes_dist * self.update + eps, dim = -1)
        PGO_update = (self.bayes_dist * self.bayes_range).sum(-1)           # posterior mean; the prior is already inside
        return PGO_update, 0

    def discount_estimate(self):
        """Like ``standard`` but with fixed geometric forgetting (``discount_forget``) on the weights."""
        self.update_weights()
        W = self.weights * self.discount_forget
        norm = self.discount_forget.sum(-1, keepdims = True)
        stim_weights = W/norm
        prior_weights = (1-stim_weights.sum(-1))                            # remaining weight goes to the current estimate
        PGO_update = stim_weights @ self.last_m_stim                        # [Mem, Mem] @ [Mem] -> [Mem]
        PGO_prior = prior_weights * self.curr_PGO[:, self.curr_step]
        return PGO_update, PGO_prior

    def hybrid_estimate(self, eps = 1e-5):
        """Posterior update from the weighted likelihood of every stimulus in memory; returns (posterior mean, 0)."""
        self.update_weights()

        # 2 options to choose from:
        # W = (self.weights * np.tri(self.max_mem))[:,:,None]             # hard cutoff version of weights
        W = (self.weights * self.discount_forget)[:,:,None]                 # gradual forgetting version of weights

        S = self.last_m_stim[:, None]                                       # last M stim
        B = self.bayes_range[None, :]                                       # every candidate PGO
        # likelihood of one stim is PGO if GO, (1 - PGO) if NOGO
        likelihood = (S*B + (1-S)*(1-B))[None,:,:]
        # each stim's likelihood raised to its weight (weight 0 -> flat), multiplied over memory
        update = self.norm_by_sum((likelihood ** W).prod(1) + eps, dim = -1)
        self.bayes_dist = self.norm_by_sum(self.bayes_dist * update + eps, dim = -1)
        PGO_update = (self.bayes_dist * self.bayes_range[None, :]).sum(-1)
        return PGO_update, 0

    def dynamic_estimate(self):
        """Discounted weighted average whose discount rate depends on the current PGO estimate."""
        self.update_weights()

        # 3 options to choose from:
        # reliability = 1-.5*self.curr_PGO[:, self.curr_step, None]        # discount factor decreases with PGO
        reliability = 1 - np.abs(self.curr_PGO[:, self.curr_step, None] - .5)   # discount factor is maximal at .5
        # reliability = .5 + np.abs(self.curr_PGO[:, self.curr_step, None] - .5)  # discount factor is minimal at .5

        # kept separate from self.discount_forget so later 'discount'/'hybrid' runs are unaffected
        self.dynamic_discount = reliability**(self.mem_range[None, :]-1)
        W = self.weights * self.dynamic_discount
        norm = self.max_mem                                                 # fixed norm: more forgetting -> more prior weight
        stim_weights = W/norm
        prior_weights = (1-stim_weights.sum(-1))
        PGO_update = stim_weights @ self.last_m_stim
        PGO_prior = prior_weights * self.curr_PGO[:, self.curr_step]
        return PGO_update, PGO_prior

    def update_weights(self):
        """Shift the stimulus weights by one step and refresh those of still-unknown past steps."""
        # shift along the memory axis; axis=1 is explicit (a flattened roll only matched because
        # the wrapped-around last column is identical in every row)
        self.weights = np.roll(self.weights, 1, axis = 1)
        self.weights[:, -1] = 0
        num = self.b_k_min_n                                                # action case numerator
        denom = self.b_n_0                                                  # action case denominator
        if not self.A:
            b_0_b_n = (self.b_0*self.b_n)                                   # (1 - b) * b^n
            num = b_0_b_n + self.a*num                                      # no-action numerator
            denom = b_0_b_n + self.a*denom                                  # no-action denominator
        W = num/(1e-20+denom)
        if not self.stim_1 or (self.A and not self.R):                      # NOGO, or action without reward
            self.unsafe = 0                                                 # current state is unsafe
        W[:, self.unsafe:] = 1                                              # weight 1 from the last known unsafe step on
        self.weights[:, :self.known] = W[:, :self.known]                    # overwrite only steps not yet known

    def normalize_weights(self, weights):
        """Mask weights to each memory length and return (stim weights, prior weights), both divided by the length."""
        weights = weights * np.tri(self.max_mem)                            # row i keeps its first i+1 stimuli
        prior_weights = (1 - weights) * np.tri(self.max_mem)                # weight missing from each stim goes to the prior
        prior_weights = (prior_weights).sum(-1)[:,None]
        denom = self.mem_range[:, None]
        return weights/denom, (prior_weights/denom).squeeze()

    def norm_by_sum(self, A, dim):
        """Divide ``A`` by its sum along ``dim``."""
        return A / A.sum(dim, keepdims = True)
 
if __name__ == "__main__":
    bayes = parameter_estimation()
    # To generate and plot trials without running the bayesian estimation:
    #   bayes.make_trials()
    #   bayes.visualize_trials()
    for estimator in ('standard', 'soft', 'discount', 'hybrid', 'dynamic'):
        bayes.get_bayes(bayes_type = estimator)