import numpy as np
from tqdm import tqdm; from enzyme.src.mouse_task.trajectory_processing import trajectory_processing;  from sklearn.linear_model import LinearRegression; import time; import pylab  as plt

class bayesian_analysis(trajectory_processing):
    def __init__(self, **params):
        self.__dict__.update(params)
    """""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
    Bayesian Flow 
    """""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
    def get_bayes_flow(self, factorize = False):
        """Run the whole Bayesian-flow pipeline.

        Stages, in order: ``init_bayes_flow`` (output traces and inputs),
        ``init_distributional_vars`` (joint belief, transition and likelihood
        models) and ``infer_theta`` (forward filtering over all steps).

        Parameters
        ----------
        factorize : bool
            When True, ``update_joint`` uses the product of the state-only and
            theta-only likelihood marginals instead of the full joint table.
        """
        self.factorize = factorize
        for stage in (self.init_bayes_flow, self.init_distributional_vars, self.infer_theta):
            stage()

    def init_bayes_flow(self):
        """Allocate the per-step output traces and unpack the network inputs.

        ``bayes_theta_backbone`` and ``bayes_belief_backbone`` start at 0.5 for
        every one of ``num_steps`` steps. The flattened ``self.data.net_input``
        (converted to numpy) is kept as ``step_inputs`` and split into the
        stimulus rows 0-1, the action row 3 and the reward row 4.
        NOTE(review): row 2 of the inputs is not used here, and ``act_traj`` is
        not read by the visible filtering code (``take_step_forward`` uses
        ``cum_act_times``) -- confirm that is intended.
        """
        self.bayes_theta_backbone = np.full(self.num_steps, .5)
        self.bayes_belief_backbone = np.full(self.num_steps, .5)
        inputs = self.to(self.flatten_trajectory(self.data.net_input), 'np')
        self.step_inputs = inputs
        self.stim_traj, self.act_traj, self.rew_traj = inputs[:2], inputs[3], inputs[4]
     
    def init_distributional_vars(self, eps = 2e-5):
        self.T_since_action = 0 
        binary = np.array([0,1])
        self.joint_dist = np.ones((3, self.bayes_resolution))
        self.joint_dist_log = np.zeros((3, self.bayes_resolution, self.num_steps))  

        # staying in the ITI after reward r
        self.T_r = np.zeros((3,3))
        self.T_r[2] = 1        

        self.T_theta = np.eye(self.bayes_resolution)       
        self.T_s = np.array([[1-1/self.exp_mean, 0, 1/self.ITI_mean],[1/self.exp_mean, 1, 0], [0,0, 1 - 1/self.ITI_mean]])                                                                     # dims: states x PGO x inputs
        self.P_X__s_theta = np.zeros((3, self.bayes_resolution, 2))                                                                                
        self.P_X__s_theta[:,:,:] = binary[None,:]*self.bayes_range[:,None] + (1-binary)[None:]*(1-self.bayes_range[:,None])       
        self.P_X__s_theta[1] = binary      
        """ softenings """ 
        self.T_theta = self.T_theta + 2e-5 
        self.T_theta = self.T_theta/self.T_theta.sum(0, keepdims=True)
        
        # theta_mean = 1/self.trial[-1] if self.theta_traj is None else self.PGO_N/(len(self.theta_traj))
        # self.T_theta = self.T_theta - (theta_mean - theta_mean/(self.bayes_resolution-1))*np.eye(self.bayes_resolution) +  theta_mean/(self.bayes_resolution-1)
        
        self.P_X__s_theta = self.P_X__s_theta + 1e-2 
        self.P_X__s_theta = self.P_X__s_theta/self.P_X__s_theta.sum(-1, keepdims=True)
        """ factorizing """ 
        self.P_X__s_only = self.P_X__s_theta.mean(1)                                                                                 # dims: states x NOGO/GO 
        self.P_X__theta_only = self.P_X__s_theta.mean(0)                                                                             # dims: theta x NOGO/GO 
           
  
    """ get emperical likelihood of observations and roll for convolution with memory """     
    # def get_emperical_likelihood(self, i = 1, consec_state = 0):
    #     while i < int(self.num_steps - self.window_len):
    #         s, e, t = self.get_PM_mem(i)
    #         PGO_ind = self.where(self.PGO_backbone[t], self.PGO_range)[0]
    #         STATE_ind = int(self.safe_backbone_flat[t])
    #         self.emperical_PX[i, :] = self.stim_traj[0, s:e][self.half_mem:-self.half_mem]

    #         consec_state = (consec_state + 1) * (self.safe_backbone_flat[t-1] == STATE_ind)
    #         discounting = self.temporal_discount**consec_state
    #         self.emperical_count[STATE_ind, PGO_ind] += discounting
    #         self.emperical_likelihood[STATE_ind, PGO_ind, :, :] += self.stim_traj[:, s:e] * discounting
            
    #         i += 1 
    #     self.emperical_likelihood = self.emperical_likelihood / self.emperical_count[:,:,None, None]

    # def get_PM_mem(self, i):
    #     t = i + self.max_mem
    #     s =  t - self.max_mem
    #     e = t + self.max_mem
    #     return s, e, t
    
    # def roll_likelihood(self):
    #     for i in range(self.max_mem):
    #         s = self.max_mem - i 
    #         e = s + self.max_mem
    #         self.rolled_likelihood[:, :, :, :, i] = self.emperical_likelihood[:,:, :,s:e]
        
        
    """""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
    infer theta given observations and computed likelihood 
    """""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
    def infer_theta(self):
        """Run the forward filter over every step and store structured outputs.

        For each step ``self.i`` the current observation is read
        (``take_step_forward``), and then the joint belief is propagated and
        logged (``get_theta_distributional``). Afterwards the raw per-step
        traces are converted with ``raw_to_structured``. Results go to the
        ``factorized_*`` attributes when ``self.factorize`` is set. Otherwise
        they go to ``flow_*``, together with the theta marginal (sum over
        states) and the state marginal (sum over theta) of the logged joint.
        The large ``joint_dist_log`` is deleted at the end.
        """
        # the factorize flag was previously printed as the literal text "(self.factorize)"
        desc = f"inferring bayesian estimates and theta distrbution (factorize = {self.factorize})"
        for self.i in tqdm(range(self.num_steps), desc=desc, disable=False):
            self.take_step_forward()
            self.get_theta_distributional()

        if self.factorize:
            self.factorized_theta_structured = self.raw_to_structured(self.bayes_theta_backbone)
            self.factorized_belief_structured = self.raw_to_structured(self.bayes_belief_backbone)
        else:
            self.flow_theta_structured = self.raw_to_structured(self.bayes_theta_backbone)
            self.flow_belief_structured = self.raw_to_structured(self.bayes_belief_backbone)
            self.theta_dist_structured = self.raw_to_structured(self.joint_dist_log.sum(0), time_dim = 2)
            self.belief_dist_structured = self.raw_to_structured(self.joint_dist_log.sum(1), time_dim = 2)
        del self.joint_dist_log
            
    def take_step_forward(self):
        # self.stim = self.stim_traj[:, self.i - self.max_mem + 1 : self.i+1]
        # self.experience = np.tile(self.stim[:, :, None], self.max_mem)                                                           # dims: inputs x duplication all elements in memory x likelihood window chunk
        # theta_distances = np.abs(self.bayes_theta_backbone[self.i-1]-self.PGO_range)
        # self.theta_est_ind = np.argmin(theta_distances)  
        self.curr_stim = int(self.stim_traj[1, self.i])
        self.curr_act = int(self.i - 1 in self.cum_act_times) 
        self.curr_rew = int(self.rew_traj[self.i])
        # self.curr_rew = int(self.rew_traj[min(self.num_steps - 1, self.i + 1)])

        # self.NOGOs = self.stim[0]
        # self.GOs = self.stim[1]
        # self.PX = np.sum(np.all(self.emperical_PX == self.NOGOs, axis=1))/len(self.emperical_PX)
        
    # def get_belief_flow(self, eps = 1e-5, integrated_belief = 0):
    #     self.unsafe_beliefs, self.safe_beliefs = [np.ones(self.max_mem) for _ in range(2)]
    #     for i in range(self.max_mem):
    #         S = int(self.where(self.stim[:,i]==1))
    #         self.unsafe_beliefs = self.unsafe_beliefs * self.rolled_likelihood[0,  self.theta_est_ind, S, i, :]
    #         self.safe_beliefs = self.safe_beliefs * self.rolled_likelihood[1,  self.theta_est_ind, S, i, :]
        
    #     norm = (self.safe_beliefs + self.unsafe_beliefs)
    #     self.unsafe_beliefs = (self.unsafe_beliefs/norm)#.round()
    #     self.safe_beliefs = (self.safe_beliefs/norm)#.round()

    #     self.bayes_belief_backbone[self.i] = -np.log(self.safe_beliefs[-1])

    def get_theta_distributional(self, eps = 1e-8):                      
        if self.curr_act:
            self.handle_bayes_action()

        self.update_joint()   
        p_state = self.joint_dist.sum(-1)
        self.bayes_theta_backbone[self.i] = (self.bayes_range[None, :] * self.joint_dist).sum()   
        self.bayes_belief_backbone[self.i] = -np.log(1-np.clip(p_state[1], a_min = eps, a_max = 1-eps))
        self.joint_dist_log[:,:,self.i] = self.joint_dist
        
    def handle_bayes_action(self):
        # if there was reward, we zero the unsafe state
        # for no reward, we zero the safe and ITI state
        # afterward, T_r transition within ITI
        self.joint_dist = self.T_r @ (self.r_mask(self.curr_rew) * self.joint_dist)
        self.joint_dist = self.joint_dist/self.joint_dist.sum()     
        self.T_since_action = 0
        
    def update_joint(self, eps = 1e-8):
        self.get_T_()        
        if self.factorize: 
            PX = self.P_X__s_only[:,self.curr_stim, None] * self.P_X__theta_only[None, :,self.curr_stim]   
            self.joint_dist = (self.T_theta @ self.joint_dist.T).T
            self.joint_dist = (self.T_s @ self.joint_dist)
        else:
            PX = self.P_X__s_theta[:,:,self.curr_stim] 
            self.joint_dist = (self.T_theta @ self.joint_dist.T).T
            self.joint_dist = (self.T_s @ self.joint_dist)
        self.joint_dist = PX * self.joint_dist 
        # self.joint_dist = np.clip(self.joint_dist, a_min = eps, a_max = 1-eps)
        self.joint_dist = self.joint_dist/self.joint_dist.sum()    
        self.T_since_action += 1 

    def r_mask(self, r):
        return np.array([[1-r, r, 1-r]]).T 

    def get_T_(self):
        delta_max = 2*self.ITI_PM
        min_leave_time = self.ITI_mean - self.ITI_PM
        delta_curr = min(delta_max - 1, self.T_since_action - min_leave_time) 
        ITI_leave_prob = 1 / (delta_max - delta_curr) if self.T_since_action >= min_leave_time else 0
        self.T_s[0,2] = ITI_leave_prob 
        self.T_s[2,2] = 1 - ITI_leave_prob
                
    # def get_likelihood_regression(self):    
    #     self.S_inds = self.where(self.safe_backbone_flat)[self.window_len: -self.window_len]
    #     self.NS_inds = self.where(1-self.safe_backbone_flat)[self.window_len: -self.window_len]
        
    #     self.S_coefs, self.NS_coefs = [np.empty((self.stim_num, self.window_len), dtype = object) for _ in range(2)]
    #     for self.coef_ind, self.sliding_ind in enumerate(self.window_range):
    #         print(f"processing input {self.coef_ind + 1}")
    #         for self.stim_ind in range (self.stim_num):
    #             self.handle_0_counts()           
    #             self.S_coefs[self.stim_ind, self.coef_ind] = LogisticRegression().fit(self.PGO_backbone[self.S_inds, None], self.stim_traj[self.stim_ind, self.S_inds + self.sliding_ind].T)
    #             self.NS_coefs[self.stim_ind, self.coef_ind] =  LogisticRegression().fit(self.PGO_backbone[self.NS_inds, None], self.stim_traj[self.stim_ind, self.NS_inds + self.sliding_ind].T)
                           
    # def handle_0_counts(self):
    #     if len(np.unique(self.stim_traj[self.stim_ind, self.S_inds + self.sliding_ind])) == 1:
    #         self.stim_traj[self.stim_ind, self.S_inds[0] + self.sliding_ind] = 0
    #         self.stim_traj[self.stim_ind, self.S_inds[1] + self.sliding_ind] = 1
    #     if len(np.unique(self.stim_traj[self.stim_ind, self.NS_inds + self.sliding_ind])) == 1:
    #         self.stim_traj[self.stim_ind, self.NS_inds[0] + self.sliding_ind] = 0
    #         self.stim_traj[self.stim_ind, self.NS_inds[1] + self.sliding_ind] = 1
        
    # def invert_theta_to_likelihood(self):
    #     self.S_likelihood, self.NS_likelihood = [np.zeros((self.stim_num, self.window_len)) for _ in range(2)]
    #     for self.theta_ind, self.theta in enumerate(self.theta_range):
    #         for self.stim_ind in range(self.stim_num):
    #             self.S_likelihood[self.stim_ind, :] = np.array( [self.S_coefs[self.stim_ind, i].predict( np.array([[self.theta]]) ) for i in range(self.window_len)] )[:, 0]
    #             self.NS_likelihood[self.stim_ind, :] = np.array( [self.NS_coefs[self.stim_ind, i].predict( np.array([[self.theta]]) ) for i in range(self.window_len)] )[:, 0]
    #         norm = self.S_likelihood + self.NS_likelihood
    #         self.S_likelihood = self.S_likelihood / norm
    #         self.NS_likelihood = self.NS_likelihood / norm 
            
    #         self.emperical_likelihood[0, self.theta_ind, :, :] = self.NS_likelihood
    #         self.emperical_likelihood[1, self.theta_ind, :, :] = self.S_likelihood
    
    """""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
    Bayesian span memory durations  
    """""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
    
    def bayes_init(self):
        self.acting = 0                                                                                                                                                        # network action 
        self.a = 1/self.exp_mean                                                                                                                                                     # p(safe)
        self.a_0 = 1 - self.a                                                                                                                                             # 1 - p(safe)
        # useful vectors 
        self.padding = np.ones(self.max_mem)                                                                                                                              # padding used as stimuli until step is larger than memory length  
        self.discount_forget = np.linspace(.85, .9, self.max_mem)[:, None]**(self.mem_range[None, :]-1)                                                                     # vector of every possible discounted past forgetting rate 
        self.bayes_reset()                                                                                                                                                # initialize new storage vectors

    def bayes_reset(self):
        # storage matrix initializations 
        self.Psafe_log, self.bayes_PGO_log, self.mem_log, self.model_free_log, self.weight_log, self.dist_log =  [np.empty(self.data_len, dtype = object) for _ in range(6)]             # Dims: [ number of trials ] = [ Trials ]
        self.curr_Psafe, self.curr_Punsafe, self.curr_PGO, self.curr_model_free, self.curr_mem = [np.ones((self.max_mem, self.trial_dur)) for _ in range(5)]                             # Dims: [ Free parameter X max trial duration] = [ Mem , Max ]
        self.curr_weight = np.ones((self.max_mem, self.trial_dur))
        self.curr_dist = np.ones((self.bayes_resolution, self.trial_dur)) 
        self.curr_n, self.running_stim = [np.zeros((self.max_mem, 1)) for _ in range(2)]                                                                                                 # Dims: [ Free parameter ] = [ Mem ]       
        self.bayes_dist = np.ones((self.max_mem, self.bayes_resolution))                                                                                                                 # Dims: [ Free parameter X bayes dist resolution] = [ Mem , Bayes ]
        self.weights = np.ones((self.max_mem, self.max_mem))                                                                                                                             # Dims: [ Free parameter X past m stimuli ] = [ Mem , Mem ]
        self.bayes_step = self.known = self.unsafe = self.loaded_n = 0                                                                                                                   # initilalizations of several variables
        
    """""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
    Trigger function
    """""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
    
    def get_bayes(self):
        """Run the trial-wise Bayesian model over all ``data_len`` trials.

        For each trial, ``self.end`` is set to ``trial_ends[trial] + 1``; this
        is the number of steps ``bayes_loop`` processes. The trial is then run
        and its buffers are copied into the logs by ``bayes_log``.
        """
        self.bayes_init()
        trial_iter = tqdm(range(self.data_len), desc="bayes loop")
        for self.curr_trial in trial_iter:
            self.end = self.trial_ends[self.curr_trial] + 1
            self.bayes_loop()
            self.bayes_log()

    """""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
    bayes processing
    """""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
    
    def bayes_log(self):
        self.model_free_log[self.curr_trial] = self.curr_model_free[:, :self.end].copy()                                                                                  # record model free window avg stim
        self.Psafe_log[self.curr_trial] = self.curr_Psafe[:, :self.end].copy()                                                                                            # record SAFE state belief 
        self.bayes_PGO_log[self.curr_trial] = self.curr_PGO[:, 1:self.end+1].copy()                                                                                       # record PGO estimate
        self.mem_log[self.curr_trial] = self.curr_mem[:,:self.end].copy()
        
        """ note we only store largest memory weights and dist due to memory constraints"""
        self.weight_log[self.curr_trial] = self.curr_weight[:,:self.end].copy()
        self.dist_log[self.curr_trial] = self.curr_dist[:,:self.end].copy()
        
        self.curr_PGO[:, 0] = self.curr_PGO[:,self.end].copy()                                                                                                            # start next trial estimate with curr trial's end estimate
        if self.curr_trial % int(self.data_len/10) == 0:                                                                                                                  
            print(str(int(100 * self.curr_trial/self.data_len)) + "%")                                                                                                    # print percentage complete
                
    def bayes_loop(self):
        for self.curr_step in range(self.end):                                                                                                                            # for each step in current trial
            self.Inference_step()                                                                                                                                         # perform inference of state beliefs 
            self.Estimation_step()                                                                                                                                        # perform estimation of parameters
            self.bayes_step += 1 
            self.unsafe += 1 
            self.known = np.clip(self.known + 1, a_min = None, a_max = self.max_mem - 1)                                                                                  # clamp the number of known steps by the maximum free param
    
    """""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
    bayes calculations
    """""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
    
    """ Inference step """
    def Inference_step(self):
        self.update_n()                                                                                                                                                   # update number of consecutive gos
        self.update_b_c()                                                                                                                                                 # update inference variables
        self.update_psafe()                                                                                                                                               # update safe state probability 
        
    def update_n(self):
        self.stim_1 = self.backbone_flat[self.bayes_step]                                                                                                                 # get if curr stim is GO 
        self.curr_n = np.minimum(self.known, self.curr_n)                                                                                                                 # cap consec GOs by last known
        self.curr_n = (1-self.acting)*self.stim_1*(self.curr_n + 1)                                                                                                       # increment if curr GO
        self.curr_n = np.minimum(self.mem_range[:, None], self.curr_n)                                                                                                    # cap consec GOs by free param (usually mem size)
        
    def update_b_c(self):     
        self.b = self.a_0*self.curr_PGO[:, self.curr_step, None].copy()                                                                                                   # use curr estimate of PGO 
        self.b_n = (self.b**self.curr_n)                                                                                                                                  # b^n
        self.b_0 = 1 - self.b                                                                                                                                             # 1 - b
        self.b_n_0 = 1 - self.b_n                                                                                                                                         # 1 - b^n 
        self.b_sum_to_n_min_1 = self.b_n_0 / self.b_0                                                                                                                     # sum over i from 0 -> n-1 of b^i = (1-b^n) / (1-b)
        n_min_k = np.maximum(0, self.curr_n - np.arange(self.max_mem))                                                                                                    # create vector of n, n-1, n-2 .... 0 
        self.b_k_min_n = self.b**(n_min_k) - self.b**self.curr_n                                                                                                          # subtract b^n from sum through all possible consec GOs 
        self.c = self.a_0*self.b_0                                                                                                                                        # (1-a) * (1-b)
    
    def update_psafe(self):
        self.curr_Punsafe[:, self.curr_step] = (self.b_n/(1 - self.c * self.b_sum_to_n_min_1)).squeeze()                                                                  # calculate psafe = b^n / ( 1 - c * (1 - b^n)/(1-b))
        self.curr_Psafe[:, self.curr_step] = (1 - self.curr_Punsafe[:, self.curr_step])                                                                                   # get Psafe as a function of n and current PGO estimate  
        
    """ Estimation step """
    def Estimation_step(self):
        """Update the PGO estimate from the window of the last ``max_mem`` stimuli.

        The window ends at the current ``bayes_step``. Until a full window
        exists, the all-ones ``padding`` is used instead. The window is stored
        newest-first in ``last_m_stim``. After the action check and the PGO
        update, ``known`` is reset to 0 on a NOGO or when acting.
        """
        start = self.bayes_step - self.max_mem + 1
        # A full window [start, bayes_step] exists as soon as start >= 0. The
        # previous `bayes_step > max_mem` test fell back to padding at
        # bayes_step == max_mem - 1 and == max_mem, although real stimuli
        # were already available there.
        if start >= 0:
            stim = self.backbone_flat[start : self.bayes_step + 1]
        else:
            stim = self.padding
        self.last_m_stim = np.flip(stim)                                   # [new ... old]
        self.check_for_action()
        self.update_PGO()
        if not self.stim_1 or self.acting:                                 # NOGO or action: all states known
            self.known = 0
            
    def check_for_action(self):
        self.acting = self.curr_step == self.act_times[self.curr_trial]                                                                                                        # Check if action was taken at curr step 
        self.R = self.rews[self.curr_trial] == 1                                                                                                                          # Check if reward was recieved at curr trial 
            
    """""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
    bayes specifics
    """""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
    
        
    def update_PGO(self):       
        """ 
        bayes types:
            standard =              simple memory window with weights for each stim in memory, for sufficient statistic
            soft =                  expert bayesian updating only after NOGO, for full posterior distribution 
            discount =              standard + decaying vector multiplied with the weights, for gradual forgetting
            hybrid =                weighted likelihood of each stim in memory, for full posterior distribution
            dynamic =              soft forgetting with dynamic discounting based on the estimated uncertainty
        
        """
        if self.bayes_type == 'standard':
            PGO_update, PGO_prior = self.standard_estimate()
        if self.bayes_type == 'soft': 
            PGO_update, PGO_prior = self.soft_estimate()
        if self.bayes_type == 'discount':
            PGO_update, PGO_prior = self.discount_estimate()
        if self.bayes_type == 'hybrid':
            PGO_update, PGO_prior = self.hybrid_estimate()
        if self.bayes_type == 'dynamic':
            PGO_update, PGO_prior = self.dynamic_estimate()
            
        self.update_remaining_currs(PGO_update, PGO_prior)
        if np.any(np.isnan(self.curr_PGO[:, self.curr_step])):                                                 # if anything fucked up print it out 
            print("nan present")
            
    def update_remaining_currs(self, PGO_update, PGO_prior):
        self.curr_model_free[:, self.curr_step] = np.cumsum(self.last_m_stim)/self.mem_range                   # model free estimate is just the average stim for each window length 
        self.curr_PGO[:, self.curr_step + 1] =  PGO_update + PGO_prior                                         # Update curr estimate (weighted average) of PGO 
        self.curr_mem[:, self.curr_step] = self.last_m_stim
            
    def standard_estimate(self):
        """Windowed PGO estimate: weighted stimulus average plus weighted prior.

        After ``update_weights``, ``normalize_weights(self.weights)`` splits
        the weights into per-stimulus weights and a prior weight.

        Returns
        -------
        tuple
            ``(stim_weights @ last_m_stim, prior_weights * current PGO)``,
            one value per free parameter (memory length).

        The last row of ``stim_weights`` is logged into ``curr_weight``.
        """
        self.update_weights()
        stim_weights, prior_weights = self.normalize_weights(self.weights)
        step = self.curr_step
        weighted_stim = stim_weights @ self.last_m_stim                    # [Mem, Mem] x [Mem] -> [Mem]
        weighted_prior = prior_weights * self.curr_PGO[:, step]            # remaining weight on the prior
        self.curr_weight[:, step] = stim_weights[-1]
        return weighted_stim, weighted_prior
    
    def soft_estimate(self, eps = 1e-5):
        """Full-posterior PGO estimate, updated only when a NOGO ends a run of GOs.

        While the consecutive-GO count of the largest memory (``curr_n[-1]``)
        is positive, the likelihood is flat (1). On a NOGO, the likelihood of
        each candidate ``sigma`` in ``bayes_range`` is
        ``sigma**n * (1 - sigma)``, where ``n`` is the run length stored at
        the previous step. ``eps`` is added before renormalising, so the
        posterior never collapses to exactly 0.

        Returns ``(posterior mean of PGO per memory row, 0)``. The prior weight
        is 0 because the posterior already carries the prior.
        """
        if self.curr_n[-1] > 0:
            self.update = 1
        else:
            self.update = (self.bayes_range ** self.loaded_n) * (1 - self.bayes_range)
        self.loaded_n = self.curr_n[-1]                                    # curr_n resets to 0 on a NOGO
        self.bayes_dist = self.norm_by_sum(self.bayes_dist * self.update + eps, dim = -1)
        PGO_update = (self.bayes_dist * self.bayes_range).sum(-1)

        # curr_dist is allocated as (bayes_resolution, trial_dur) in bayes_reset, so the old
        # 3-index assignment `curr_dist[:, :, step]` raised IndexError. Log the
        # largest-memory row, as hybrid_estimate does.
        self.curr_dist[:, self.curr_step] = self.bayes_dist[-1]
        return PGO_update, 0
    
    def discount_estimate(self):
        """Windowed PGO estimate with gradual forgetting.

        After ``update_weights``, the window weights are multiplied by
        ``discount_forget`` and divided by the row sums of ``discount_forget``.
        Whatever weight remains (``1 - row sum``) is put on the current PGO
        estimate.

        Returns ``(weighted stimulus average, weighted prior)``, one value per
        memory length. NOTE(review): unlike ``standard_estimate``, nothing is
        logged into ``curr_weight`` here.
        """
        self.update_weights()
        forget = self.discount_forget
        stim_weights = (self.weights * forget) / forget.sum(-1, keepdims= True)
        prior_weights = 1 - stim_weights.sum(-1)
        PGO_update = stim_weights @ self.last_m_stim                      # [Mem, Mem] x [Mem] -> [Mem]
        PGO_prior = prior_weights * self.curr_PGO[:, self.curr_step]
        return PGO_update, PGO_prior
    
    def hybrid_estimate(self, eps = 1e-5):                                                                     # hybrid has both memory (either hard cutoff or discounted forgetting) and distributions 
        """Posterior-distribution PGO estimate built from a weighted memory window.

        Row ``i`` corresponds to one memory length and uses the ``i + 1``
        newest stimuli: ``F = np.tri(max_mem)`` is a hard cutoff, and
        ``last_m_stim`` is ordered newest-first. For each candidate PGO in
        ``bayes_range``, the row likelihood is the product over the in-window
        stimuli of ``W * P(stim | PGO) + (1 - W) * stim``. Here ``W`` are the
        window weights masked by ``F``. Out-of-window entries contribute a
        factor of 1 via ``Z``. The normalised product replaces ``bayes_dist``
        (no prior is carried over), and its mean is the estimate.

        Returns ``(posterior mean of PGO per memory row, 0)``. ``eps`` is only
        used by the commented-out ML/MAP variants.

        NOTE(review): for a NOGO (``S == 0``) the factor is ``W * P(stim | PGO)``,
        not an interpolation towards 1 (the commented ``likelihood ** W``
        variant would give 1 at ``W == 0``). A NOGO with ``W == 0`` inside the
        window therefore zeroes the whole row. Depending on how ``norm_by_sum``
        handles a zero sum, this may yield NaN. Confirm this is intended.
        """
        self.update_weights()                                                                                  # Update weights   
        """"""""""""""""""""" 2 options to choose from """""""""""""""""""""
        F = np.tri(self.max_mem)                                                                               # hard cutoff forgetting
        # F = self.discount_forget                                                                               # gradual forgetting  
        """"""""""""""""""""" 2 options to choose from """""""""""""""""""""

        Z = (1-F)[:,:,None]                                                                                    # 1 outside the window, so masked entries multiply the product by 1
        W = (self.weights * F )[:,:,None] 
        S = self.last_m_stim[:, None]                                                                          # last M stim, newest first
        B = self.bayes_range[None, :]                                                                          # every candidate PGO
        likelihood = (S*B + (1-S)*(1-B))[None,:,:]                                                             # per-stim likelihood: PGO if GO, (1 - PGO) if NOGO
                
        """ standard bayes """
        L = (likelihood * W + (1-W)*S) * F[:,:,None] + Z
        update = self.norm_by_sum(L.prod(1), dim = -1)                                                         # product over window stimuli, normalised over PGO candidates
        """ exponential weighting """
        # update = self.norm_by_sum( ( likelihood ** W ).prod(1), dim = -1)                                    # likelihood: multiply likelihood of each stim, to the power of the weight, e.g. weight = 0 means all PGO are equally probable        

        # NOTE(review): the bare string below is a no-op left by the author; the ML/MAP variants after it are disabled
        "not sure what is going on here"
        """ ML estimate """ 
        # self.bayes_dist = self.norm_by_sum(self.bayes_range[None,:] * update + eps, dim = -1)                # get posterior distribution
        """ MAP estimate """ 
        # self.bayes_dist = self.norm_by_sum(self.bayes_dist * update + eps, dim = -1)                         # get posterior distribution

        self.bayes_dist = update                                                                               # posterior = normalised window likelihood (no prior carried over)
        
        PGO_update =  (self.bayes_dist * self.bayes_range[None, :]).sum(-1)                                    # posterior mean of PGO for every memory length

        self.curr_dist[:, self.curr_step] = self.bayes_dist[-1]                                                 # log largest-memory posterior
        self.curr_weight[:, self.curr_step] = W.squeeze(2)[-1]                                                  # log largest-memory stim weights 
        return PGO_update, 0
    
        
    def dynamic_estimate(self):
        """Estimate PGO using a discount factor that depends on the current PGO estimate.

        ``self.weights`` is first refreshed via ``update_weights``. The weights are then
        discounted by ``reliability ** (mem_range - 1)``, where reliability is
        ``1 - |PGO - 0.5|``. The result is divided by ``self.max_mem``, so stronger
        forgetting leaves more weight for the prior.

        Side effect: stores the discount matrix in ``self.discount_forget``.

        Returns:
            (PGO_update, PGO_prior): the weighted sum of the last stimuli and the
            prior-weighted current PGO, one entry per memory duration.
        """
        self.update_weights()
        current_pgo = self.curr_PGO[:, self.curr_step, None]
        # Alternatives for the reliability term (kept for reference):
        #   1 - .5 * current_pgo           -> discount factor decreases with PGO
        #   .5 + np.abs(current_pgo - .5)  -> discount factor is minimal at .5
        reliability = 1 - np.abs(current_pgo - .5)                 # discount factor is maximal at .5
        exponents = self.mem_range[None, :] - 1
        self.discount_forget = reliability ** exponents            # one discount row per memory duration
        discounted = self.weights * self.discount_forget
        stim_weights = discounted / self.max_mem                   # fixed norm: forgetting shifts mass to the prior
        prior_weights = 1 - stim_weights.sum(-1)
        estimate_from_stim = stim_weights @ self.last_m_stim
        estimate_from_prior = prior_weights * self.curr_PGO[:, self.curr_step]
        return estimate_from_stim, estimate_from_prior
    
    
    def update_weights(self):
        """Advance ``self.weights`` by one step and refresh the most recent columns.

        Steps:
          1. Shift the weights one column to the right and zero column 0.
          2. Compute ratio = numerator / (1e-20 + denominator):
             - when acting: ``b_k_min_n / b_n_0``;
             - otherwise: both terms become ``b_0 * b_n + a * (...)``.
          3. If the stimulus is NOGO (``not stim_1``), or an action went unrewarded,
             reset ``self.unsafe`` to 0.
          4. Set the ratio to 1 from column ``self.unsafe`` onward.
          5. Copy the first ``self.known + 1`` ratio columns into ``self.weights``.
        """
        # np.roll without ``axis`` shifts the flattened array. For a 2-D weight matrix
        # the only entries that wrap across rows land in column 0, which is zeroed next.
        self.weights = np.roll(self.weights, 1)
        self.weights[:, 0] = 0
        if self.acting:
            numerator = self.b_k_min_n                              # action case
            denominator = self.b_n_0
        else:
            baseline = self.b_0 * self.b_n                          # (1-b) * b^n per the original notes
            numerator = baseline + self.a * self.b_k_min_n
            denominator = baseline + self.a * self.b_n_0
        ratio = numerator / (1e-20 + denominator)                   # fresh array, so inputs are not mutated below
        became_unsafe = not self.stim_1 or (self.acting and not self.R)
        if became_unsafe:
            self.unsafe = 0
        ratio[:, self.unsafe:] = 1
        cutoff = self.known + 1
        self.weights[:, :cutoff] = ratio[:, :cutoff]
                
    def normalize_weights(self, weights):
        """Split ``weights`` into normalised stimulus weights and a per-row prior weight.

        Entries above the diagonal are zeroed, so row i only keeps its first i+1
        weights. The prior weight of a row is the total of ``1 - weight`` over its kept
        entries. Both outputs are divided by ``self.mem_range`` (per row).

        Returns:
            (stim_weights, prior_weights): stim_weights has the masked shape of
            ``weights``; prior_weights is squeezed to drop the singleton axis.
        """
        lower = np.tri(self.max_mem)
        kept = weights * lower
        leftover = ((1 - kept) * lower).sum(-1)[:, None]            # weight not assigned to stimuli goes to the prior
        row_len = self.mem_range[:, None]
        return kept / row_len, (leftover / row_len).squeeze()
    
    def norm_by_sum(self, A, dim):
        """Return ``A`` divided by its sum along axis ``dim`` (summed axis kept for broadcasting)."""
        totals = A.sum(dim, keepdims = True)
        return A / totals

    """""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
    analyzing bayes
    """""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
    def run_REG(self):
        """Run the regression pipeline.

        Steps: compute the Bayes quantities, build flattened indices (not planted,
        held out by ``self.held_out``), assemble the IV matrix, then fit one regressor
        per DV.
        """
        self.get_bayes()
        self.get_indices(flatten = True, planted = False, eps_final = self.held_out)
        self.get_IV_DV()
        self.regress_to_DVs()

    def get_IV_DV(self):
        """Clear the selected-layer bookkeeping, then rebuild ``self.IV`` via ``IV_loop``."""
        self.layer_names, self.num_layers = "", 0
        self.IV_loop()

    def IV_loop(self):
        """Stack the flattened layers named in ``self.regress_on`` into ``self.IV``.

        Calls ``flatten_bayes`` first. For each selected layer:
          - its name is appended to ``self.layer_names`` as ``"/NAME"``;
          - the layer is concatenated onto ``self.IV`` along axis 0 (the first
            selected layer is assigned directly);
          - ``self.num_layers`` is incremented.
        """
        self.flatten_bayes()
        candidates = (
            ("CELL", self.C_gate_flat),
            ("FORGET", self.F_gate_flat),
            ("INPUT", self.I_gate_flat),
            ("OUTPUT", self.O_gate_flat),
            ("STM", self.output_flat),
            ("LTM", self.LTM_flat),
            # ("PCs", self.PC_flat),
        )
        for name, layer in candidates:
            if name not in self.regress_on:
                continue
            self.layer_names += "/" + name
            self.IV = layer if self.num_layers == 0 else np.concatenate((self.IV, layer))
            self.num_layers += 1

    def regress_to_DVs(self):
        """Fit one ``LinearRegression`` per dependent variable on the current ``self.IV``.

        For each ``(DV, name)`` in ``zip(self.DVs, self.DV_names)``:
          - fits ``self.IV.T -> DV.T``;
          - stores the absolute coefficients in ``self.heats[name]``;
          - stores the fitted model in ``self.regressor[name]``.

        ``self.DV_name`` is used as the loop target, so it is left set to the last name.
        """
        for DV, self.DV_name in zip(self.DVs, self.DV_names):
            print(f"Regression for {self.DV_name}")
            reg = LinearRegression().fit(self.IV.T, DV.T)
            # |coef_|; for a 2-D target this is (n_targets, n_features) per sklearn's docs.
            self.heats[self.DV_name] = np.abs(reg.coef_)
            self.regressor[self.DV_name] = reg
                        
    def get_reconstruction(self, mem):
        """Rebuild the IV matrix, then reconstruct every DV using output column ``mem``.

        ``mem`` is stored as ``self.mem`` before anything else runs, because the
        reconstruction helpers read it.
        """
        self.mem = mem
        self.get_IV_DV()                                            # flattened data keeps trajectory indexing
        print("reconstructing parameters")
        self.reconstruct_DVs()

    def reconstruct_DVs(self):
        """Score each fitted regressor, then reconstruct its DV per split and step.

        For every ``(DV, name)`` in ``zip(self.DVs, self.DV_names)``:
          - allocates zero-filled ``self.DV_mus[name]`` with shape
            (max_mem, split_num, max_traj_steps);
          - allocates zero-filled ``self.recons[name]`` with shape
            (split_num, max_traj_steps). ``recon_splits`` fills both arrays;
          - stores in ``self.score[name]`` the squared error of
            ``self.regressor[name]`` on ``self.IV.T`` versus ``DV.T``, averaged over
            samples (one value per target column).

        ``self.DV``, ``self.DV_name`` and ``self.DV_r`` are kept as attributes because
        ``recon_splits`` / ``get_recon_mus`` read them.
        """
        for self.DV, self.DV_name in zip(self.DVs, self.DV_names):
            self.DV_mus[self.DV_name] = np.zeros((self.max_mem, self.split_num, self.max_traj_steps))
            self.recons[self.DV_name] = np.zeros((self.split_num, self.max_traj_steps))
            self.DV_r = self.regressor[self.DV_name]
            pred = self.DV_r.predict(self.IV.T)
            self.score[self.DV_name] = ((pred - self.DV.T) ** 2).mean(0)   # per-target MSE
            self.recon_splits()
                
    def recon_splits(self):
        """Fill the reconstruction arrays for every split and every trajectory step.

        ``self.split_i``, ``self.split_curr`` and ``self.step`` are set as attributes
        before each helper call, because the helpers read them. They keep their last
        values after the loops finish.
        """
        for idx, current in enumerate(self.split):
            self.split_i, self.split_curr = idx, current
            self.get_split_inds()
            for step in range(self.max_traj_steps):
                self.step = step
                self.get_trajectory_step_inds()
                self.get_recon_mus()
                
    def get_recon_mus(self):
        """Record the mean reconstruction and the mean true DV for the current split/step.

        The reconstruction is the regressor's prediction on ``self.IV[:, self.step_inds]``,
        restricted to output column ``self.mem`` and averaged over samples.

        NOTE(review): when ``self.step_inds`` is empty, the reconstruction is recorded
        as 0, while the DV mean of an empty slice comes out as NaN (numpy warns).
        Confirm which is intended.
        """
        columns = self.IV[:, self.step_inds]
        if len(self.step_inds) > 0:
            predicted = self.DV_r.predict(columns.T)[:, self.mem]
        else:
            predicted = np.zeros(1)
        self.recons[self.DV_name][self.split_i, self.step] = predicted.mean()
        self.DV_mus[self.DV_name][:, self.split_i, self.step] = self.DV[:, self.step_inds].mean(-1)
                
                    