from tqdm import tqdm
from enzyme.src.mouse_task.net_dynamics import net_dynamics; import numpy as np; import pylab as plt; import seaborn as sns
from scipy.ndimage import convolve
 
class flow_fields(net_dynamics):
    """Class for simulating flow-field dynamics.

    Two families of analysis live here:

    * Flow fields: the mean one-step change of latent variables (network PCs,
      or belief/theta of the network and model agents), binned on a grid over
      the variables' previous values (``get_phase_space``,
      ``get_quadrant_space``).
    * Behaviour from flow: trial-by-trial simulation of three model agents on
      freshly sampled trials -- a factorized agent (run with
      ``self.factorize = True``, logged in ``factorized_*``), a joint Bayesian
      "flow"/"bayes" agent (``self.factorize = False``, logged in ``flow_*``)
      and a "window" agent that estimates theta from a sliding window of
      recent stimuli (logged in ``window_*``) -- followed by summary
      statistics and plots.

    Many helpers used below (``get_indices``, ``where``, ``get_GO_NOGO_inds``,
    ``init_distributional_vars``, ``to_dist``, ``get_hazard``, ``correlate``,
    ``update_joint``, ``get_T_``, ``handle_bayes_action``, ``r_mask``,
    ``make_standard_plot``, ...) are not defined in this class; presumably
    they come from ``net_dynamics``.
    """

    def __init__(self, **params):
        """Store every keyword argument as an instance attribute.

        ``net_dynamics.__init__`` is not called; all state comes from ``params``.
        """
        self.__dict__.update(params)

    def get_phase_space(self, til_action = True, from_action = False):
        """Bin samples on a (PC1, PC2) grid and average their one-step PC change.

        Writes, each shaped ``[stim (NOGO, GO), PC1 bin, PC2 bin]``:
            self.dPC1, self.dPC2 -- mean of ``PC_diff[0]`` / ``PC_diff[1]`` per cell
            self.PC_count        -- number of samples per cell
        Cells without samples stay 0. The grid values of each PC are stored in
        ``self.uniques``.

        Args:
            til_action, from_action: forwarded to ``get_indices``.
        """
        self.round_dim = 1
        step = 10.0 ** -self.round_dim      # grid spacing matching the rounding precision
        self.get_indices(col = 'all', flatten = True, til_action = til_action, from_action = from_action)
        PC_round = self.PC_prev.round(self.round_dim)
        # Build each grid from integer multiples of `step`, from min to max
        # INCLUSIVE. A float np.arange(min, max, step) has an exclusive stop and,
        # depending on rounding error, may or may not include max -- so samples
        # at a PC's maximum could be silently dropped from every cell.
        self.uniques = np.array(
            [np.unique((np.arange(np.round(PC.min() / step), np.round(PC.max() / step) + 1) * step).round(self.round_dim))
             for PC in PC_round],
            dtype = object)
        self.dPC1, self.dPC2, self.PC_count = [np.zeros((2, len(self.uniques[0]), len(self.uniques[1]))) for _ in range(3)]  # dims: [Stim, PC1, PC2]
        self.get_GO_NOGO_inds()
        stim_groups = [self.NOGO_inds, self.GO_inds]

        for PC1_i, PC1 in enumerate(self.uniques[0]):
            PC1_mask = PC_round[0, :] == PC1            # invariant over the PC2 loop
            for PC2_i, PC2 in enumerate(self.uniques[1]):
                # Samples in this grid cell (computed once, shared by both stimuli).
                cell_inds = self.where(PC1_mask * (PC_round[1, :] == PC2))
                for stim_i, stim_inds in enumerate(stim_groups):
                    inds = np.intersect1d(cell_inds, stim_inds)
                    if len(inds) > 0:
                        self.dPC1[stim_i, PC1_i, PC2_i] = self.PC_diff[0, inds].mean(-1)
                        self.dPC2[stim_i, PC1_i, PC2_i] = self.PC_diff[1, inds].mean(-1)
                        self.PC_count[stim_i, PC1_i, PC2_i] = len(inds)

    @staticmethod
    def _three_bins(values, low, high):
        """Return boolean masks for ``values < low``, ``low < values < high`` and ``values > high``.

        Inequalities are strict, so values exactly equal to a threshold fall in no bin.
        """
        return [values < low, (values > low) * (values < high), values > high]

    def get_quadrant_space(self):
        """Average belief/theta flow in a 3x3 partition of (belief, theta) space.

        For each agent (network, joint "flow" Bayes, factorized) and each
        stimulus (NOGO, GO), samples are binned by previous belief into
        low (< belief_bot) / mid / high (> belief_top) and by previous theta
        into low (< theta_bot) / mid / high (> theta_top). The mean belief
        change and mean theta change of each bin are written to
        ``self.flow_quadrants`` with dims
        ``[agent, stim, belief bin, theta bin, variable (belief, theta)]``.
        Empty bins give NaN (mean of an empty selection).
        """
        self.get_indices(col = 'all', flatten = True, til_action = True)
        self.get_GO_NOGO_inds()
        self.flow_quadrants = np.zeros((3, 2, 3, 3, 2)) # dims: agent (net/joint/factorized), stim (NOGO/GO), x_axis (belief), y_axis (theta), diff_variable (belief/theta)
        # Previous values ("where" a sample is) and one-step changes ("what" it does), per agent.
        where_data = [[self.net_belief_prev, self.net_theta_prev],
                      [self.flow_belief_prev, self.flow_theta_prev],
                      [self.factorized_belief_prev, self.factorized_theta_prev]]
        what_data = [[self.net_belief_diff, self.net_theta_diff],
                     [self.flow_belief_diff, self.flow_theta_diff],
                     [self.factorized_belief_diff, self.factorized_theta_diff]]

        # Bin thresholds, stored on self for later use.
        self.theta_top, self.theta_bot = .6, .3
        self.belief_top, self.belief_bot = 1.4, .7

        for agent_i, (self.agent_where, self.agent_what) in enumerate(zip(where_data, what_data)):
            self.what_belief, self.what_theta = self.agent_what
            self.where_belief, self.where_theta = self.agent_where
            belief_bins = self._three_bins(self.where_belief, self.belief_bot, self.belief_top)
            theta_bins = self._three_bins(self.where_theta, self.theta_bot, self.theta_top)
            for stim_i, stim_inds in enumerate([self.NOGO_inds, self.GO_inds]):
                cell_inds = {}
                for belief_i, belief_mask in enumerate(belief_bins):
                    for theta_i, theta_mask in enumerate(theta_bins):
                        inds = np.intersect1d(stim_inds, self.where(belief_mask * theta_mask))
                        cell_inds[belief_i, theta_i] = inds
                        self.flow_quadrants[agent_i, stim_i, belief_i, theta_i, 0] = self.what_belief[inds].mean()
                        self.flow_quadrants[agent_i, stim_i, belief_i, theta_i, 1] = self.what_theta[inds].mean()
                # Named cell indices for the most recent (agent, stim); each
                # name now refers to exactly the cell it describes
                # (left/right = belief, bot/top = theta).
                self.bot_left_inds, self.top_left_inds = cell_inds[0, 0], cell_inds[0, 2]
                self.bot_right_inds, self.top_right_inds = cell_inds[2, 0], cell_inds[2, 2]
                self.mid_mid_inds = cell_inds[1, 1]

    # ---------------------------------------------------------------- behavior from flow plots

    def flow_to_behavior(self):
        """Simulate all three model agents on ``data_len`` trials and summarise the result.

        Resets ``joint_dist`` to an unnormalised ``ones((3, bayes_resolution))``
        and allocates one object-dtype log of length ``data_len`` per quantity
        before running ``simulate_flow_trials`` and ``postprocess_flow_behavior``.
        """
        self.init_distributional_vars()
        self.joint_dist = np.ones((3, self.bayes_resolution))
        self.T_since_action = 0
        self.flow_act_log, self.flow_rew_log, self.flow_waiting_log, self.flow_waiting_pot_log, self.window_waiting_log, self.window_waiting_pot_log, self.flow_ana_curve, \
        self.factorized_act_log, self.factorized_rew_log, self.factorized_waiting_log, self.factorized_waiting_pot_log = [np.zeros(self.data_len, dtype = object) for _ in range(11)]
        self.simulate_flow_trials()
        self.postprocess_flow_behavior()

    def postprocess_flow_behavior(self):
        """Turn the per-trial logs of ``simulate_flow_trials`` into summary statistics.

        Computes waiting-time distributions and hazards, per-PGO means/stds of
        reward, action time, waiting time and reward rate, the per-PGO mean of
        the stored analytic reward-rate curves, and -- for each trial index
        ``t`` passed as ``get_indices(From=t, Til=t)`` -- per-PGO mean waits
        plus correlations of waiting time with the current and last PGO.
        """
        # The logs are allocated with dtype=object; cast every waiting-time log
        # to float so all three agents hand numeric arrays to to_dist,
        # correlate and the means below (not only the flow agent's logs).
        self.flow_waiting_log = self.flow_waiting_log.astype(float)
        self.flow_waiting_pot_log = self.flow_waiting_pot_log.astype(float)
        self.factorized_waiting_log = self.factorized_waiting_log.astype(float)
        self.factorized_waiting_pot_log = self.factorized_waiting_pot_log.astype(float)
        self.window_waiting_log = self.window_waiting_log.astype(float)
        self.window_waiting_pot_log = self.window_waiting_pot_log.astype(float)

        self.factorized_PDF, _, _, _, self.factorized_xax = self.to_dist(self.factorized_waiting_log)
        self.factorized_pot_PDF, _, _, _, self.factorized_pot_xax = self.to_dist(self.factorized_waiting_pot_log)
        self.window_pot_PDF, _, _, _, self.window_pot_xax = self.to_dist(self.window_waiting_pot_log)
        self.flow_pot_PDF, _, _, _, self.flow_pot_xax = self.to_dist(self.flow_waiting_pot_log)
        self.window_PDF, _, _, _, self.window_xax = self.to_dist(self.window_waiting_log)
        self.flow_PDF, _, _, _, self.flow_xax = self.to_dist(self.flow_waiting_log)
        self.flow_pot_hazard = self.get_hazard(self.flow_pot_PDF)
        self.factorized_pot_hazard = self.get_hazard(self.factorized_pot_PDF)
        self.flow_hazard = self.get_hazard(self.flow_PDF)
        # One row per trial of averaged analytic reward-rate curves (filled by run_flow_trial).
        self.ana_curve_flat = np.vstack(self.flow_ana_curve)

        self.flow_rate_mu, self.flow_rate_std, self.factorized_rate_mu, \
        self.flow_rew_mu, self.flow_rew_std, self.factorized_rew_mu, \
        self.flow_act_mu, self.flow_act_std, self.factorized_act_mu, \
        self.flow_wait_mu, self.flow_wait_std, self.factorized_wait_mu = [np.zeros(self.PGO_N) for _ in range(12)]

        # Allocated here (last axis is over repetitions) but not filled by this method.
        self.flow_wait_all, self.factorized_wait_all = [np.zeros((self.PGO_N, self.num_trials), dtype=object) for _ in range(2)]

        # NOTE: despite its name, ana_curve_var stores a standard deviation.
        self.ana_curve_mu, self.ana_curve_var = [np.zeros((self.PGO_N, self.thresh_N)) for _ in range(2)]

        for j, p in enumerate(self.PGO_range):
            self.get_indices(curr_PGO = p, planted = False, prep_mus = False, cross_validation = None)
            inds = self.trial_inds
            # Spreads are standard deviations across trials, not standard errors.
            self.flow_rew_mu[j] = self.flow_rew_log[inds].mean()
            self.flow_rew_std[j] = self.flow_rew_log[inds].std()
            self.flow_act_mu[j] = self.flow_act_log[inds].mean()
            self.flow_wait_mu[j] = self.flow_waiting_log[inds].mean()
            self.flow_wait_std[j] = self.flow_waiting_log[inds].std()

            # Reward rate = mean reward / (mean action time + mean ITI).
            self.flow_rate_mu[j] = self.flow_rew_mu[j] / (self.flow_act_mu[j] + self.ITI_mean)
            self.ana_curve_mu[j] = self.ana_curve_flat[inds, :].mean(0)
            self.ana_curve_var[j] = self.ana_curve_flat[inds, :].std(0)
            self.flow_rate_std[j] = self.flow_rew_std[j] / (self.flow_act_mu[j] + self.ITI_mean)

            self.factorized_rew_mu[j] = self.factorized_rew_log[inds].mean()
            self.factorized_act_mu[j] = self.factorized_act_log[inds].mean()
            self.factorized_wait_mu[j] = self.factorized_waiting_log[inds].mean()
            self.factorized_rate_mu[j] = self.factorized_rew_mu[j] / (self.factorized_act_mu[j] + self.ITI_mean)

        self.bayes_theta_corr, self.bayes_last_theta_corr, self.factorized_theta_corr, self.factorized_last_theta_corr, self.window_theta_corr, self.window_last_theta_corr = [np.zeros(self.num_trials) for _ in range(6)]
        self.bayes_from_switch, self.factorized_from_switch, self.window_from_switch = [np.zeros((self.PGO_N, self.num_trials)) for _ in range(3)]
        self.bayes_from_switch_all, self.factorized_from_switch_all, self.window_from_switch_all = [np.zeros((self.PGO_N, self.num_trials), dtype = object) for _ in range(3)]
        for t in range(self.num_trials):
            for j, p in enumerate(self.PGO_range):
                self.get_indices(From = t, Til = t, curr_PGO = p, planted = False, prep_mus = False, cross_validation = None)
                # NOTE(review): logs are read one position before trial_inds
                # (presumably an alignment offset); an index of 0 wraps to the
                # last trial -- confirm this is intended.
                prev = self.trial_inds - 1
                self.bayes_from_switch[j, t] = self.flow_waiting_log[prev].mean()
                self.factorized_from_switch[j, t] = self.factorized_waiting_log[prev].mean()
                self.window_from_switch[j, t] = self.window_waiting_log[prev].mean()

                self.bayes_from_switch_all[j, t] = self.flow_waiting_log[prev]
                self.factorized_from_switch_all[j, t] = self.factorized_waiting_log[prev]
                self.window_from_switch_all[j, t] = self.window_waiting_log[prev]
            self.get_bayes_corrs(t)

    def get_bayes_corrs(self, t):
        """Correlate each agent's waiting time with the current and last PGO at trial index ``t``.

        Uses the same ``trial_inds - 1`` offset as ``postprocess_flow_behavior``.
        """
        self.get_indices(From = t, Til = t, planted = False, prep_mus = False, cross_validation = None)
        prev = self.trial_inds - 1
        target = self.flow_waiting_log[prev]
        factorized_target = self.factorized_waiting_log[prev]
        window_target = self.window_waiting_log[prev]
        self.bayes_theta_corr[t] = self.correlate(X = self.PGOs[self.trial_inds], Y = target)
        self.bayes_last_theta_corr[t] = self.correlate(X = self.last_PGOs[self.trial_inds], Y = target)
        self.window_theta_corr[t] = self.correlate(X = self.PGOs[self.trial_inds], Y = window_target)
        self.window_last_theta_corr[t] = self.correlate(X = self.last_PGOs[self.trial_inds], Y = window_target)
        self.factorized_theta_corr[t] = self.correlate(X = self.PGOs[self.trial_inds], Y = factorized_target)
        self.factorized_last_theta_corr[t] = self.correlate(X = self.last_PGOs[self.trial_inds], Y = factorized_target)

    def simulate_flow_trials(self):
        """Run every trial for the factorized, joint Bayes and window agents, in that order.

        Each agent gets freshly sampled trials (``reset_flow_trial`` draws new
        random durations and stimuli). Only the factorized run stores the
        analytic reward-rate curve. The window agent's stimulus buffer
        (length ``2 * exp_mean``) and belief are initialised once and carry
        over from trial to trial.
        """
        self.factorize = True
        for self.curr_trial in tqdm(range(self.data_len), desc="generating factorized behavior"):
            self.reset_flow_trial()
            self.run_flow_trial(store_curve = True)

        self.factorize = False
        for self.curr_trial in tqdm(range(self.data_len), desc="generating Bayes behavior"):
            self.reset_flow_trial()
            self.run_flow_trial()

        self.window_agent = .5*np.ones(int(self.exp_mean*2))
        self.window_belief = .5*np.ones((3, 1))
        for self.curr_trial in tqdm(range(self.data_len), desc="generating window behavior"):
            self.reset_flow_trial()
            self.run_window_trial()

    def run_window_trial(self):
        """Step the window agent through one trial until ``trial_step`` ends it.

        The observed stimulus is 1 during the safe period, else
        ``trial_stim[curr_step]`` before acting or ``trial_ITI[ITI_i]`` during
        the ITI. Theta is estimated as the ``bayes_range`` value closest to the
        mean of the sliding stimulus buffer; the belief is propagated with
        ``T_s``, weighted by the stimulus likelihood and normalised.
        """
        while self.ongoing:
            self.curr_stim = 1 if self.safe else self.trial_stim[self.curr_step] if not self.ITI else self.trial_ITI[self.ITI_i]
            # Slide the stimulus window: drop the oldest entry, append the newest.
            self.window_agent = np.roll(self.window_agent, -1)
            self.window_agent[-1] = self.curr_stim
            self.window_theta_ind = np.argmin((self.window_agent.mean() - self.bayes_range)**2)

            self.get_T_()
            PX = self.P_X__s_theta[:, self.window_theta_ind, self.curr_stim]
            self.window_belief = (self.T_s @ self.window_belief)
            self.window_belief = PX[:, None] * self.window_belief
            self.window_belief = self.window_belief/self.window_belief.sum()
            self.T_since_action += 1
            if self.acted == 0:
                self.window_action()
            self.trial_step()
            self.curr_step += 1

    def _optimal_belief(self, p_theta):
        """Return ``(opt_belief, weighted_rew_rate)`` for a column of theta weights.

        ``weighted_rew_rate`` is the theta-weighted reward rate
        ``total_ana_rews / (total_ana_acts + ITI_mean)`` summed over axis 0;
        ``opt_belief`` is the theta-weighted ``total_ana_rews`` at the index
        that maximises it. ``p_theta`` is a ``(bayes_resolution, 1)`` column
        (as built by the callers).
        """
        weighted_p_safe = (self.total_ana_rews*p_theta).sum(0)
        weighted_rew_rate = (p_theta*(self.total_ana_rews/(self.total_ana_acts + self.ITI_mean))).sum(0)
        return weighted_p_safe[np.argmax(weighted_rew_rate)], weighted_rew_rate

    def window_action(self):
        """Act (log wait times and update belief) once P(state 1) reaches the optimal threshold.

        The threshold comes from ``_optimal_belief`` with all theta weight on
        the window agent's current theta estimate.
        """
        p_theta = np.zeros((self.bayes_resolution, 1))
        p_theta[self.window_theta_ind] = 1
        opt_belief, _ = self._optimal_belief(p_theta)
        curr_LLH = self.window_belief[1]
        if curr_LLH >= opt_belief:
            self.window_waiting_pot_log[self.curr_trial] = self.curr_step - self.last_pot_NOGO
            last_nogo = self.nogos[self.where(self.nogos<=self.curr_step)][-1]
            self.window_waiting_log[self.curr_trial] = self.curr_step - last_nogo
            # Rewarded only if the action comes at or after the stimulus duration.
            self.curr_rew = (self.curr_step >= self.curr_stim_dur)
            self.window_belief = self.T_r @ (self.r_mask(self.curr_rew) * self.window_belief)
            self.window_belief = self.window_belief/self.window_belief.sum()
            self.T_since_action = 0
            self.acted = 1

    def reset_flow_trial(self):
        """Sample a new trial for ``self.curr_trial`` and reset the trial state.

        Stimulus duration ~ exponential(exp_mean) clipped to [exp_min, exp_max];
        ITI duration ~ uniform(ITI_mean +/- ITI_PM), at least 2. Stimulus and
        ITI sequences are Bernoulli(PGO) draws; the first stimulus is forced
        to 0. ``nogos`` is ``self.where(self.trial_stim, 0)`` -- presumably the
        indices where ``trial_stim == 0`` (confirm in ``where``).
        """
        self.ongoing = True
        self.PGO = self.PGOs[self.curr_trial]
        self.safe, self.ITI = [False for _ in range(2)]
        self.curr_step, self.ITI_i, self.acted = [0 for _ in range(3)]
        self.curr_stim_dur = int(np.clip(np.random.exponential(self.exp_mean), a_min = self.exp_min, a_max = self.exp_max))
        self.curr_ITI_dur = max(2, int(np.random.uniform(self.ITI_mean - self.ITI_PM, self.ITI_mean + self.ITI_PM)))
        self.trial_stim = np.random.choice([0, 1], p = [1-self.PGO, self.PGO], size = self.curr_stim_dur)
        self.trial_ITI = np.random.choice([0, 1], p = [1-self.PGO, self.PGO], size = self.curr_ITI_dur)
        self.trial_stim[0] = 0
        self.nogos = self.where(self.trial_stim, 0)
        self.last_pot_NOGO = self.nogos[-1]

    def run_flow_trial(self, store_curve = False):
        """Step a Bayesian (factorized or joint) agent through one trial.

        Args:
            store_curve: if True, accumulate ``self.weighted_rew_rate`` at every
                step and store its per-step average in
                ``flow_ana_curve[curr_trial]``. NOTE(review): the rate is only
                recomputed while the agent has not yet acted (``flow_action`` is
                skipped afterwards), so post-action steps repeat the last value.
        """
        while self.ongoing:
            self.curr_stim = 1 if self.safe else self.trial_stim[self.curr_step] if not self.ITI else self.trial_ITI[self.ITI_i]
            if self.acted == 0:
                self.flow_action()
            self.update_joint()
            self.trial_step()
            self.curr_step += 1
            if store_curve:
                self.flow_ana_curve[self.curr_trial] = self.flow_ana_curve[self.curr_trial] + self.weighted_rew_rate
        if store_curve:
            self.flow_ana_curve[self.curr_trial] = self.flow_ana_curve[self.curr_trial]/self.curr_step

    def flow_action(self, eps = 1e-8):
        """Act once log P(state 1) under ``joint_dist`` reaches the optimal log-threshold.

        On acting, logs waiting times, reward and action step into the
        ``factorized_*`` logs if ``self.factorize`` else the ``flow_*`` logs.

        Args:
            eps: floor applied to P(state 1) before taking its log.
        """
        sampled_thresh = self.theta_dist_to_optimal_thresh()
        p_state = self.joint_dist.sum(-1)
        curr_LLH = np.log(np.clip(p_state[1], a_min = eps, a_max = None))

        if curr_LLH >= sampled_thresh:
            last_nogo = self.nogos[self.where(self.nogos<=self.curr_step)][-1]
            # Rewarded only if the action comes at or after the stimulus duration.
            self.curr_rew = (self.curr_step >= self.curr_stim_dur)
            self.handle_bayes_action()
            self.acted = 1
            if self.factorize:
                self.factorized_waiting_pot_log[self.curr_trial] = self.curr_step - self.last_pot_NOGO
                self.factorized_waiting_log[self.curr_trial] = self.curr_step - last_nogo
                self.factorized_rew_log[self.curr_trial] = self.curr_rew
                self.factorized_act_log[self.curr_trial] = self.curr_step
            else:
                self.flow_waiting_pot_log[self.curr_trial] = self.curr_step - self.last_pot_NOGO
                self.flow_waiting_log[self.curr_trial] = self.curr_step - last_nogo
                self.flow_rew_log[self.curr_trial] = self.curr_rew
                self.flow_act_log[self.curr_trial] = self.curr_step

    def theta_dist_to_optimal_thresh(self):
        """Return the log of the optimal belief threshold under the current theta marginal.

        The theta marginal is ``joint_dist.sum(0)``. Also stores the
        theta-weighted reward-rate curve in ``self.weighted_rew_rate``.
        """
        p_theta = self.joint_dist.sum(0)[:, None]
        opt_belief, self.weighted_rew_rate = self._optimal_belief(p_theta)
        return np.log(opt_belief)

    def trial_step(self):
        """Advance the trial state machine after one time step.

        Stimulus phase -> safe phase once ``curr_step >= curr_stim_dur - 1``;
        after an action the agent is in the ITI and ``ITI_i`` advances each
        step. The trial ends when the ITI finishes or at step ``trial_dur - 2``.
        """
        stim_ongoing = self.ITI == self.safe        # both False only during the stimulus phase
        stim_ending = self.curr_step >= (self.curr_stim_dur - 1)
        ITI_ending = self.ITI_i >= (self.curr_ITI_dur - 1)
        if stim_ongoing and stim_ending:
            self.safe = True
        if self.acted:
            self.safe = False
            self.ITI = True
            self.ITI_i += 1
        if (not stim_ongoing and not self.safe and ITI_ending) or (self.curr_step == self.trial_dur - 2):
            self.ongoing = False

    def gen_flow_behavior(self):
        """Simulate flow-based behaviour and draw the standard diagnostic plots."""
        self.flow_to_behavior()
        self.plot_flow_wait_from_last_dists()
        self.plot_flow_behavior_dists()
        self.plot_bayes_from_switch()

    def plot_ICLR_flow_behavior(self, SVG = False):
        """Plot per-PGO waiting-time PDFs (flow, factorized, network) and mean wait vs. trials from switch.

        Args:
            SVG: currently unused.
        """
        fig, ax = plt.subplots(3, 1, figsize=(6, 15))
        for i in range(self.PGO_N):
            ax[0].plot(self.flow_xax[i], self.flow_PDF[i], c = self.PGO_COLOR[i], linestyle= self.bayes_style, linewidth = 2)
            ax[1].plot(self.factorized_xax[i], self.factorized_PDF[i], c = self.PGO_COLOR[i], linestyle='--', linewidth = 2)
            ax[2].plot(self.wait_xax[i], self.wait_PDF[i], c = self.PGO_COLOR[i], linewidth = 2)
        for axis in ax:
            axis.set_xlim([1, 20])

        fig, ax = plt.subplots(1, 1)
        for i in range(self.PGO_N):
            ax.plot(self.bayes_from_switch[i], c = self.PGO_COLOR[i], linestyle = self.bayes_style)
            ax.plot(self.wait_from_switch[i], c = self.PGO_COLOR[i])

    def plot_flow_behavior_dists(self):
        """Plot per-PGO waiting-time PDFs and hazards for the model agents and the network."""
        self.plot_simple_dist(self.flow_xax, self.flow_PDF, title = "flow pdf", x_min = 0, x_max = 20)
        self.plot_simple_dist(self.flow_pot_xax, self.flow_pot_PDF, title = "flow pot pdf", x_min = -2, x_max = 20, y_max = .5)
        self.plot_simple_dist(self.factorized_xax, self.factorized_PDF, title = "factorized pdf", x_min = -2, x_max = 20, y_max = .5)
        self.plot_simple_dist(self.window_xax, self.window_PDF, title = "window pdf", x_min = 0, x_max = 20, y_max = .5)
        self.plot_simple_dist(self.flow_pot_xax, self.flow_pot_hazard, title = "flow pot hazard", x_min = -2, x_max = 20, y_max = 5)
        self.plot_simple_dist(self.factorized_pot_xax, self.factorized_pot_hazard, title = "factorized pot hazard", x_min = -2, x_max = 20)
        self.plot_simple_dist(self.wait_pot_xax, self.wait_pot_PDF, title = "net pot pdf", x_min =0, x_max = 20, y_max = .5)
        self.plot_simple_dist(self.wait_pot_xax, self.wait_pot_hazard, title = "net pot hazard", x_min =0, x_max = 20, y_max = 5)

    def plot_simple_dist(self, xax, yax, title, x_min = -10, x_max = 20, y_max = None):
        """Plot one curve per PGO (``xax[i]`` vs ``yax[i]``) on the current figure and show it.

        Args:
            xax, yax: per-PGO x values and y values, indexed by PGO.
            title: plot title.
            x_min, x_max: x-axis limits.
            y_max: optional upper y limit (lower limit 0); unset when None.
        """
        for i in range(self.PGO_N):
            plt.plot(xax[i], yax[i], c = self.PGO_COLOR[i])
        if y_max is not None:
            plt.ylim([0, y_max])
        plt.xlim([x_min, x_max]); plt.title(title); plt.show()

    def plot_bayes_from_switch(self):
        """For bayes, factorized, window and network agents, plot mean wait vs. trials from
        block switch and the current-/last-PGO correlations of waiting time."""
        self.make_standard_plot(ys = self.bayes_from_switch, cols = self.PGO_COLOR, alphs = self.PGO_ALPHA,
            plots = self.PGO_PLOT, leg = self.PGO_LEG, xlab = "trial from block switch", title = "bayes avg wait from last nogo")
        self.make_standard_plot(ys = [self.bayes_theta_corr, self.bayes_last_theta_corr], cols = ["C0", "C1"], alphs = [1]*2, plots = ['line']*2,
            leg = ["current PGO corr", "last PGO corr"], xlab = "trial from block switch", title = "bayes correlation difference")

        self.make_standard_plot(ys = self.factorized_from_switch, cols = self.PGO_COLOR, alphs = self.PGO_ALPHA,
            plots = self.PGO_PLOT, leg = self.PGO_LEG, xlab = "trial from block switch", title = "factorized avg wait from last nogo")
        self.make_standard_plot(ys = [self.factorized_theta_corr, self.factorized_last_theta_corr], cols = ["C0", "C1"], alphs = [1]*2, plots = ['line']*2,
            leg = ["current PGO corr", "last PGO corr"], xlab = "trial from block switch", title = "factorized correlation difference")

        self.make_standard_plot(ys = self.window_from_switch, cols = self.PGO_COLOR, alphs = self.PGO_ALPHA,
            plots = self.PGO_PLOT, leg = self.PGO_LEG, xlab = "trial from block switch", title = "window avg wait from last nogo")
        self.make_standard_plot(ys = [self.window_theta_corr, self.window_last_theta_corr], cols = ["C0", "C1"], alphs = [1]*2, plots = ['line']*2,
            leg = ["current PGO corr", "last PGO corr"], xlab = "trial from block switch", title = "window correlation difference")

        self.make_standard_plot(ys = self.wait_from_switch, cols = self.PGO_COLOR, alphs = self.PGO_ALPHA,
            plots = self.PGO_PLOT, leg = self.PGO_LEG, xlab = "trial from block switch", title = "network avg wait from last nogo")
        self.make_standard_plot(ys = [self.PGO_corr, self.last_PGO_corr], cols = ["C0", "C1"], alphs = [1]*2, plots = ['line']*2,
            leg = ["current PGO corr", "last PGO corr"], xlab = "trial from block switch", title = "network correlation difference")