import numpy as np
from enzyme import FIGPATH, TEXTWIDTH
from enzyme.src.helper import save_plot; from enzyme.src.mouse_task.data_processing import data_processing; import pylab as plt; import seaborn as sns; 
import scipy.linalg as lin;  import matplotlib as mpl;   from cycler import cycler;  import pickle as p; import matplotlib.colors as cm_cols
import torch; import torch.nn as nn; from scipy.stats import entropy as ent; import matplotlib.cm as cm;

class mouse_plotting(data_processing):
    # Default for the `if self.show:` guards used throughout: a callable, hence truthy.
    # Passing show=False (or any falsy value) to the constructor suppresses those show() calls.
    show = plt.show

    def __init__(self, **params):
        """Create the plotter by storing every keyword argument as an instance attribute.

        Attributes are written straight into the instance namespace, so no
        descriptors or properties are triggered.
        """
        vars(self).update(params)
    
    """ Plot single trials"""   
    def plot_episode(self, manager = None):
        """Plot the task timeline of ``self.curr_trial`` against every recorded network signal.

        Sets ``self.height``, ``self.xax``, ``self.beeps`` and ``self.act_color`` (read by
        ``make_trial_plot``).  Then, for each ``(signal, label)`` from ``get_net_recent``,
        draws one figure.  The agent-history window is ``[ends[-2], ends[-1])``, where
        ``ends = cumsum(end_times + 1)``.

        NOTE(review): that window is always the *last* segment, while the timeline uses
        ``curr_trial`` — presumably ``curr_trial`` is the last trial; confirm.
        """
        trial = self.curr_trial
        self.height = .25
        n_steps = int(self.end_times[trial])
        self.xax = np.arange(0, n_steps, 1/self.temp_resolution)
        self.beeps = np.repeat(self.backbone[trial, :n_steps], self.temp_resolution)
        self.act_color = 'r' if self.rew[trial] == 0 else 'g'
        boundaries = np.cumsum(self.end_times + 1)
        start, stop = int(boundaries[-2]), int(boundaries[-1])
        for signal, label in self.get_net_recent(manager, start, stop):
            self.make_trial_plot(net_data = self.to(signal, 'np').T, title = label)
            
    """ Plot visualization of network activity"""     
    def get_net_recent(self, manager, s, e):
        """Return an iterator of ``(signal, label)`` pairs for the agent's history window ``[s, e)``.

        Signals, in order: ``action_probs`` row 1 ("lick_prob"), ``values``,
        ``Q_values[1] - Q_values[0]``, ``outputs``, ``LTMs`` and the four gate slices
        ``gates[0..3]`` (gates are detached before slicing).
        """
        agent = manager.agent
        window = slice(s, e)
        gates = agent.gates.detach()
        signals = [
            agent.action_probs[1, window],
            agent.values[window],
            agent.Q_values[1, window] - agent.Q_values[0, window],
            agent.outputs[:, window],
            agent.LTMs[:, window],
            *(gates[g, :, window] for g in range(4)),
        ]
        labels = ["lick_prob", "pred value", "Q diff", "net output", "LSTM_LTM",
                  "LSTM_f_gate", "LSTM_i_gate", "LSTM_c_gate", "LSTM_o_gate"]
        return zip(signals, labels)

    """ Plot visualization of trial"""     
    def make_trial_plot(self, net_data, title):
        """Draw one network signal over the current trial's timeline.

        Shaded bands: deeppink ``[0, stim_end]``, green ``[stim_end, W4L_end]`` and red
        ``[W4L_end, end_times]`` (all for ``self.curr_trial``).  A dot at y=1.1 marks
        ``act_time + 1`` in ``self.act_color``.  A new figure/axes is created only when
        ``self.show`` is truthy; otherwise the existing ``self.ax`` is reused.
        """
        t = self.curr_trial
        if self.show:
            fig = plt.figure(figsize=(10,5))
            self.ax = fig.add_subplot(1, 1, 1)
        ax = self.ax
        ax.plot(net_data, alpha = .5)
        ax.fill_between(self.xax, self.beeps, alpha = .5)
        bands = [
            (self.W4L_end[t], self.end_times[t], 'r'),
            (self.stim_end[t], self.W4L_end[t], 'g'),
            (0, self.stim_end[t], 'deeppink'),
        ]
        for lo, hi, colour in bands:
            ax.fill_between(np.arange(lo, hi + 1), y1 = self.height, color = colour, alpha = .5)
        ax.scatter(self.act_time[t] + 1, 1.1, c = self.act_color)
        ax.set_ylim([min(0, net_data.min() - .5), max(1.2, net_data.max() + .5)])
        ax.set_title(f"PGO = {self.PGO} with {title}")
        if self.show: plt.show()
        
###############################################################################################################################################

    """ all episode plotting"""  
    def plotting(self):
        """Run every episode-level summary plot, in a fixed order."""
        for draw in (self.plot_avg_act_rew_wait, self.plot_wait_from_last,
                     self.plot_analytical, self.plot_from_switch):
            draw()
                
    """ plot avg reward, action time and wait from last"""
    def plot_avg_act_rew_wait(self):
        """Plot per-episode averages per PGO: reward, reward rate, action time, wait from last."""
        panels = [("episode_rews", "avg Reward"),
                  ("episode_rew_rate", "avg rew rate"),
                  ("episode_action_times", "avg Action times"),
                  ("episode_wait_from_last", "avg wait from last")]
        for attr, title in panels:
            self.make_standard_plot(ys = getattr(self, attr), alphs = self.PGO_ALPHA, cols = self.PGO_COLOR,
                plots = self.PGO_PLOT, leg = self.PGO_LEG, xlab = r"episode (per $\theta$)", title = title)

    """ plot wait from last behavior"""
    def plot_wait_from_last(self):
        """Plot the wait-from-last-NOGO PDF plus the NOGO- and potential-aligned hazard functions per PGO."""
        per_pgo = dict(cols = self.PGO_COLOR, alphs = self.PGO_ALPHA, plots = self.PGO_PLOT, leg = self.PGO_LEG)
        self.make_standard_plot(xax = self.wait_xax, ys = self.wait_PDF, xlim = [0, 35],
            xlab = "wait from last nogo PDF", **per_pgo)
        self.make_standard_plot(xax = self.wait_xax, ys = self.wait_hazard, xlim = [0, 35], ylim = [0, 2],
            xlab = "time", title = "Hazard function of wait from last NOGO", **per_pgo)
        self.make_standard_plot(xax = self.wait_pot_xax, ys = self.wait_pot_hazard, xlim = [-5, 35], ylim = [0, 2],
            xlab = "time", title = "Hazard function of wait from last potential", **per_pgo)
            
    """ plot network behavior as a function of trial from block switch"""
    def plot_from_switch(self):
        """Plot behaviour and PGO correlations as a function of trials since the last block switch."""
        self.make_standard_plot(ys = self.wait_from_switch, cols = self.PGO_COLOR, alphs = self.PGO_ALPHA,
            plots = self.PGO_PLOT, leg = self.PGO_LEG, xlab = "trial from block switch", title = "avg wait from last nogo")
        # current vs previous PGO correlation, drawn as two plain lines
        correlations = [self.PGO_corr, self.last_PGO_corr]
        self.make_standard_plot(ys = correlations, cols = ["C0", "C1"], alphs = [1, 1], plots = ["line", "line"],
            leg = ["current PGO corr", "last PGO corr"], xlab = "trial from block switch", title = "correlation difference")
        

    def plot_analytical(self):
        """Compare the network with analytical optima: reward curves, wait distributions, thresholds."""
        for draw in (self.plot_numer_v_ana, self.plot_ana_vs_net_dists, self.plot_threshold_adaptation):
            draw()
                
    """ plot numerical, analytical and network comparisons"""
    def plot_numer_v_ana(self):
        """Plot analytical/numerical curves against network means, then a reward-rate summary per PGO.

        Three comparison panels share the x-axis ``self.wait_mu`` (action time,
        P(reward), reward rate).  The reward-rate panel also gets wait/rate error bars.
        A final scatter compares analytical-optimal, network and optimal-fixed-threshold
        reward rates per PGO, with the network's ``rate_std`` as error bars.
        """
        panels = [("ana_acts", "numer_acts", "act_mu", "action time"),
                  ("ana_rews", "numer_rews", "rew_mu", "P(reward)"),
                  ("ana_rates", "numer_rates", "rate_mu", "reward rate")]
        for analytic, numeric, network, ylab in panels:
            self.make_numer_v_ana_plot(A_ys = getattr(self, analytic), N_ys = getattr(self, numeric),
                agent_x_mus = self.wait_mu, agent_y_mus = getattr(self, network), ylab = ylab)
            if network == "rate_mu":
                axis = plt.gca()
                for i in range(self.PGO_N):
                    axis.errorbar(self.wait_mu[i], self.rate_mu[i], yerr = self.rate_std[i], xerr=self.wait_std[i],
                                  capsize=5, color=self.PGO_COLOR[i])
            if self.show: plt.show()

        # reward rates as a function of theta
        self.make_standard_plot(ys = [self.ana_max_rate, self.rate_mu, self.fixed_thresh_rate],
            alphs = [1]*3, cols = [None]*3, xticks = self.PGO_range, ylim = [.95*self.fixed_thresh_rate.min(), 1.05*self.ana_max_rate.max()],
            title = "Reward Rate comparisons", xlab = "PGO", plots = ["scatter"]*3,
            leg = ["analytical optimal rate", "network rate", "optimal fixed threshold"], show=False)
        plt.gca().errorbar(np.arange(len(self.rate_mu)), self.rate_mu, yerr = self.rate_std, capsize=5, color='C1')
        if self.show: plt.show()
        
    
        
    def make_numer_v_ana_plot(self, A_ys, N_ys,  agent_x_mus, agent_y_mus, ylab, xlim = None, ylim = None, SVG = False, curve_var = None):
        """Overlay analytical curves, numerical checks and network means for every PGO.

        Args:
            A_ys: per-PGO analytical curves, plotted against ``self.ana_xax``.
            N_ys: per-PGO numerical values, plotted as dots against ``self.numer_xax``.
            agent_x_mus, agent_y_mus: per-PGO network means, drawn as black-edged dots.
            ylab: y-axis label; also used in the SVG file name.
            xlim, ylim: optional upper limits; the axis then spans ``[0, limit]``.
            SVG: if true, save the figure via ``self.save_SVG``.
            curve_var: optional per-PGO half-widths shaded around the analytical curves.

        A new 8x8 figure is opened only when ``self.show`` is truthy; the caller shows it.
        """
        if self.show:
            plt.figure(figsize=(8, 8))
        for i in range(self.PGO_N):
            c = self.PGO_COLOR[i]
            plt.plot(self.ana_xax, A_ys[i], color = c, linewidth = 5, alpha = .5)
            if curve_var is not None:
                plt.fill_between(self.ana_xax, A_ys[i]-curve_var[i], A_ys[i]+curve_var[i], color =  c, alpha = .5)
            plt.plot(self.numer_xax, N_ys[i], "o", color = c)                     # numerical sanity check
        # network means are drawn after all curves
        for i in range(self.PGO_N):
            plt.plot(agent_x_mus[i], agent_y_mus[i], color = self.PGO_COLOR[i], marker = 'o', markeredgecolor = 'k', markersize = 10, alpha = 1)

        plt.xlabel("Mean Consecutive GOs before First Lick");  plt.ylabel(ylab)
        plt.title("analytical vs network")
        if ylim is not None:             plt.ylim([0, ylim])
        if xlim is not None:             plt.xlim([0, xlim])
        if SVG:                          self.save_SVG(f'analytical {ylab}')

    def plot_ana_vs_net_dists(self, SVG = False):
        """Plot the network's wait-from-last-NOGO PDF for every PGO on shared axes (x limited to [1, 25]).

        Args:
            SVG: if true, save the figure as "net dists" via ``self.save_SVG`` (before showing).
        """
        if self.show:
            plt.figure(figsize=(20,10))
        for i, PGO in enumerate(self.PGO_range):
            plt.plot(self.wait_xax[i], self.wait_PDF[i].T, c = self.PGO_COLOR[i], linewidth = 3, label = f'network {PGO} PDF')
        plt.xlabel("Wait from last NOGO", fontdict =  self.font_dict); plt.ylabel("PDF", fontdict = self.font_dict)
        plt.xlim([1,25])
        if SVG:
            self.save_SVG("net dists")
        if self.show: plt.show()
        
    def plot_bayes_network_rew_rate(self,  xlim = (1, 21), SVG = False):
        """Plot analytical reward-rate curves with network and Bayes operating points.

        Draws a vertical line at the mean optimal fixed threshold, an 'x' at the mean
        factorized-Bayes (wait, rate), the analytical rate curve per PGO, and per-PGO
        network ('o') and joint-Bayes ('*') (wait, rate) points.

        Args:
            xlim: x-axis limits, or ``None`` to leave autoscaling.
            SVG: if true, save the figure via ``self.save_SVG``.
        """
        if self.show:
            plt.figure(figsize=(8, 8))
        plt.axvline(self.fixed_max_thresh.mean(), color = 'k', linewidth = 1, alpha = .4)
        plt.plot(self.factorized_wait_mu.mean(), self.factorized_rate_mu.mean(), marker = 'x', markeredgecolor = 'k', markersize = 20, alpha = 1)
        for i in range(self.PGO_N):
            plt.plot(self.ana_xax, self.ana_rates[i], color = self.PGO_COLOR[i], linewidth = 4, alpha = .8)

        for i in range(self.PGO_N):
            c = self.PGO_COLOR[i]
            plt.plot(self.wait_mu[i], self.rate_mu[i], color = c, marker = 'o', markeredgecolor = 'k', markersize = 13, alpha = 1)
            plt.plot(self.flow_wait_mu[i], self.flow_rate_mu[i], color = c, marker = '*', markeredgecolor = 'k', markersize = 16, alpha = 1)

        plt.xlabel("mean consecutive GOs before first lick");  plt.ylabel("reward rate")
        plt.title("analytical vs network")
        if xlim is not None:             plt.xlim(xlim)
        # Save BEFORE showing. With blocking/inline backends plt.show() releases the figure,
        # so a later save would capture an empty canvas. This matches the other methods here.
        if SVG:                          self.save_SVG('analytical reward rate')
        if self.show: plt.show()
        
    def plot_threshold_adaptation(self):
        """Compare analytical, network and fixed-threshold policies in three figures.

        1. Similarity bars (analytical / network / fixed threshold), with the y-axis
           zoomed to ``[0.8, 1.01] * ana_similarity``.
        2. Mean reward rate over PGO as a fraction of the analytical optimum, including
           the optimal random-acting policy.
        3. Per-PGO ``wait_PDF_avg`` vs. the optimal flexible and fixed thresholds.
        """
        if self.show: plt.figure(figsize=(10, 5))
        similarities = [self.ana_similarity, self.net_similarity, self.fixed_similarity]
        for pos, value in enumerate(similarities):
            plt.bar(pos, value)
        plt.ylim([self.ana_similarity*.8, self.ana_similarity*1.01])
        plt.title("Cosin similarity with optimal average wait from last NOGO")
        plt.legend(["analytical", "network", "optimal fixed threshold"])
        if self.show: plt.show()

        if self.show: plt.figure(figsize=(10, 5))
        optimal = self.ana_max_rate.mean()
        rates = [optimal, self.rate_mu.mean(), self.fixed_thresh_rate.mean(), self.rand_max_rate.mean()]
        for pos, rate in enumerate(rates):
            plt.bar(pos, rate/optimal)
        plt.title("Percent optimal reward rate averaged over PGO")
        plt.legend(["analytical", "network", "optimal fixed threshold", "optimal random acting"])
        if self.show: plt.show()

        self.make_standard_plot(ys = [self.wait_PDF_avg, self.ana_max_thresh, [self.fixed_thresh]*self.PGO_N], alphs = [1]*3, cols = [None]*3,
            plots = ["scatter"]*3, xlab = "PGO", xticks = self.PGO_range, leg = ["Wait from last nogo", "optimal flexible threshold", "optimal fixed threshold"])
       
    
    def plot_theta_MSE_fast_v_stable(self):
        """Compare theta-estimate MSE on the first trial after a switch vs. ``num_trials`` trials later.

        Figure 1 scatters first-trial MSE (x) against later-trial MSE (y) for each
        model-free (MF) estimator (marker size grows with its index), the network and
        the Bayes estimate.  Figure 2 plots network and Bayes MSE divided by MF MSE
        against PGO for both trial offsets, with a reference line at 1.
        """
        plt.subplots(1, figsize = (15, 15))
        plt.scatter(self.MF_mse[:, :, 0].mean(1), self.MF_mse[:, :, 1].mean(1), marker = 'x',
                    s = (1 + np.arange(self.MF_num))*10, alpha = .35)
        for mse, marker in ((self.net_mse, 'o'), (self.bayes_mse, '*')):
            plt.scatter(mse[:, 0].mean(), mse[:, 1].mean(), marker = marker, s = 200)
        plt.xlabel("first trial from switch"); plt.ylabel(f"{self.num_trials} trial")
        plt.title("MSE error"); plt.legend(["MF", "network", "bayes"])
        if self.show: plt.show()

        plt.subplots(1, figsize = (15, 15))
        for col, trial in enumerate([1, self.num_trials]):
            plt.plot(self.PGO_range, self.net_ratio[:, col], label = f"network ratio {trial} trial")
            plt.plot(self.PGO_range, self.bayes_ratio[:, col], linestyle = self.bayes_style, label = f"bayes ratio {trial} trial")
        plt.ylabel("Estimate MSE / Model Free MSE")
        plt.axhline(1); plt.xlabel("theta"); plt.legend(); plt.title("MSE ratio for estimate over (optimal) MF window")
        if self.show: plt.show()
        
        
                
###############################################################################################################################################

    def init_traj_plot(self):
        """Open a fresh trajectory figure with step ticks relabelled relative to ``self.label_offset``.

        Also refreshes ``self.trajectories_title`` (the "aligned on ..." suffix used in titles).
        """
        plt.subplots(1, 1, figsize = (self.fig_W, self.fig_H), tight_layout = True)
        self.trajectories_title = f"\naligned on {self.align_on}"
        steps = np.arange(self.max_traj_steps)
        plt.xticks(ticks = steps, labels = steps - self.label_offset)
        plt.xlabel(self.traj_xlabel)
        
    def init_3D_plot(self):
        """Create a 3-D axes labelled PC 1 / PC 2 / PC 3 and return it.

        A new 15x15 figure is opened when ``self.show`` is truthy.  Otherwise the 3-D
        axes is added to the current figure.  Either way ``self.fig`` is set, and the
        axes is always bound, so ``show=False`` no longer raises UnboundLocalError.
        """
        self.fig = plt.figure(figsize = (15,15)) if self.show else plt.gcf()
        ax = self.fig.add_subplot(111, projection='3d')
        ax.set_xlabel('PC 1', fontdict=self.font_dict)
        ax.set_ylabel('PC 2', fontdict=self.font_dict)
        ax.set_zlabel('PC 3', fontdict=self.font_dict)
        return ax


    """ plot network activity in PCA basis"""
    def plot_on_PCA_basis(self):
        """Draw every split's mean PC trajectory (``self.PC_mus``) in 3-D, once per view in ``self.angles3D``.

        Each trajectory's points are coloured twice: by time via ``self.Cmap``, then by
        ``self.act_col`` at half alpha.  They are joined by a line in the split colour.
        Leaves ``self.split_i`` set to the last split index.
        """
        for view in self.angles3D:
            ax = self.init_3D_plot()
            time_rgba = self.Cmap.to_rgba(np.linspace(0, 1, self.max_traj_steps))
            for self.split_i in range(self.split_num):
                overlay = self.act_col[self.split_i, :self.max_traj_steps, :]
                line_colour = self.split_cols[self.split_i]
                X, Y, Z = [self.PC_mus[d, self.split_i, :] for d in range(self.PC_dim)]
                ax.scatter3D(X, Y, Z, s = 50, c = time_rgba, alpha = 1)
                ax.scatter3D(X, Y, Z, s = 50, c = overlay, alpha = .5)
                ax.plot3D(X, Y, Z, c = line_colour)
            for set_label, text in zip((ax.set_xlabel, ax.set_ylabel, ax.set_zlabel), ('PC 1', 'PC 2', 'PC 3')):
                set_label(text)
            ax.view_init(view[0], view[1])
            if self.show: plt.show()
            
    def plot_interactable_PCA(self):
        """Select block-split trials (From=0, Til=50, aligned on action) and draw their mean PC trajectories in 3-D.

        Recomputes the trajectory without plotting, then draws one time/act-coloured
        point cloud plus a split-coloured line per split, viewed from ``(0, -90)``.
        """
        selection = dict(col = 'block', From = 0, Til = 50, stim_above = None, stim_below = None, eps_init = 0,
                         planted = False, plant_PGO = None, plant_ID = None, prev_PGO = None, curr_PGO = None,
                         rew = None, align_on = 'action', flatten = True)
        self.get_indices(**selection)
        self.run_trajectory(plot = False)
        ax = self.init_3D_plot()
        time_rgba = self.Cmap.to_rgba(np.linspace(0, 1, self.max_traj_steps))
        for self.split_i in range(self.split_num):
            overlay = self.act_col[self.split_i, :self.max_traj_steps, :]
            line_colour = self.split_cols[self.split_i]
            X, Y, Z = [self.PC_mus[d, self.split_i, :] for d in range(self.PC_dim)]
            ax.scatter3D(X, Y, Z, s = 50, c = time_rgba, alpha = 1)
            ax.scatter3D(X, Y, Z, s = 50, c = overlay, alpha = .5)
            ax.plot3D(X, Y, Z, c = line_colour)
        ax.set_xlabel('PC 1'); ax.set_ylabel('PC 2'); ax.set_zlabel('PC 3')
        ax.view_init(0, -90)
        if self.show: plt.show()
        
    def plot_means_on_PCA_basis(self):
        """Scatter each trial's averaged network output, projected to PC space, coloured by split.

        Each trial tensor is averaged over its last axis (presumably time — confirm),
        moved to CPU and transformed with ``self.proj``.  Repeated for every view in
        ``self.angles3D``.
        """
        for view in self.angles3D:
            ax = self.init_3D_plot()
            for self.split_i, self.split_curr in enumerate(self.split):
                self.get_split_inds()
                colour = self.split_cols[self.split_i]
                trials = self.net_output[self.trial_inds][self.split_inds]
                projected = np.array([self.proj.transform(trial.cpu().mean(-1, keepdims=True).T)
                                      for trial in trials]).squeeze(1)
                xs, ys, zs = [projected[:, d] for d in range(self.PC_dim)]
                ax.scatter3D(xs, ys, zs, s = 50, color = colour, alpha = .5)
            ax.set_xlabel('PC 1'); ax.set_ylabel('PC 2'); ax.set_zlabel('PC 3')
            ax.view_init(view[0], view[1])
            if self.show: plt.show()
                                
    """ plot projection onto each PC across time and gate mean activity"""
    def plot_trajectories(self):
        """Run all trajectory plots: PCA projections, network vs Bayes, network activity, behaviour."""
        for draw in (self.plot_PCA_trajectories, self.plot_net_v_bayes_trajectories,
                     self.plot_network_trajectories, self.plot_behavior_trajectories):
            draw()
    
    def plot_PCA_trajectories(self):
        """Plot per-PC trajectory projections (the derivative plots, plot_PCA_derivatives, are disabled)."""
        self.plot_PCA_projections()
        
    def plot_PCA_PGO_corr(self):
        """Plot, per trajectory step, |correlation| between each PC projection and the split condition.

        The flattened network output is projected with ``self.proj``.  Each of the
        first ``self.PC_dim`` components is correlated step by step via
        ``get_PC_PGO_corr``, and the absolute values are plotted, one line per PC.
        """
        self.init_traj_plot()
        PCs = self.proj.transform(self.output_flat.T).T
        corr = np.zeros((self.PC_dim, self.max_traj_steps))
        self.get_split_inds(split_after = True)
        for PC in range(self.PC_dim):
            corr[PC, :] = self.get_PC_PGO_corr(PCs[PC, :])
        plt.plot(np.abs(corr.T), '-o')
        plt.title(f"Strength of correlation of PCs to {self.col}")
        # one legend entry per plotted component, so the legend tracks PC_dim
        plt.legend([f"PC{i + 1}" for i in range(self.PC_dim)])
        if self.show: plt.show()
        
    def plot_PC2_PGO_corr(self):
        """Plot the step-wise correlation of PC 2 with the 'block' and the 'plant_PGO' columns.

        Both selections use planted trials aligned on onset.  The two curves are
        labelled "previous context" and "new context".
        """
        curves = []
        for context_col in ('block', 'plant_PGO'):
            self.get_indices(col = context_col, planted = True, align_on = 'onset', flatten = True)
            self.run_trajectory(plot = False)
            PCs = self.proj.transform(self.output_flat.T).T
            curves.append(self.get_PC_PGO_corr(PCs[1, :]))
        self.init_traj_plot()
        for curve in curves:
            plt.plot(curve.T)
        plt.title("Correlation of PC 2 to PGO")
        plt.legend(["previous context", "new context"])
        if self.show: plt.show()
        
    def get_PC_PGO_corr(self, PC):
        """Correlate a flattened PC projection with the split condition at every trajectory step.

        Args:
            PC: 1-D projection over all flattened samples; indexed with ``self.step_inds``.

        Returns:
            Array of length ``self.max_traj_steps``.  Steps with no samples stay 0.
            ``self.step`` is left at the last step index as a side effect.
        """
        corr = np.zeros(self.max_traj_steps)
        self.get_split_inds(split_after = True)
        for self.step in range(self.max_traj_steps):
            self.get_trajectory_step_inds()
            if not len(self.step_inds):
                continue
            corr[self.step] = self.correlate(X = PC[None, self.step_inds], Y = self.step_split_cond).squeeze()
        return corr

    def plot_PCA_projections(self, SVG = False):
        """Plot each principal component's per-split mean trajectory; ``self.dim`` tracks the current PC."""
        n_components = self.PC_mus.shape[0]
        for self.dim in range(n_components):
            label = f"{self.PC_labs[self.dim]}"
            self.plot_general_trajectory(self.PC_mus[self.dim], name = label, SVG = SVG)
        
    def plot_network_trajectories(self):
        """Plot mean activity per split for every named network activity, with a shaded ±SE band.

        SE is the std over axis 0 of ``|activity|`` divided by ``sqrt(self.hid_dim)``.
        Axis 0 is presumably the hidden-unit axis; confirm.
        """
        for name, activity in zip(self.activity_names, self.activities):
            self.init_traj_plot()
            for self.split_i in range(self.split_num):
                split_act = activity[:, self.split_i, :]
                SE = abs(split_act).std(0) / np.sqrt(self.hid_dim)
                mu = split_act.mean(0)
                self.plot_trajectory_split(mu, SE)
                steps = np.arange(len(mu))
                plt.fill_between(steps, mu - SE, mu + SE, color = self.split_cols[self.split_i, :], alpha = .05)
                plt.title(f"{name} network activity" + self.trajectories_title)
            plt.legend()
            if self.show: plt.show()
            
    def plot_net_v_bayes_trajectories(self):
        """Plot Bayes and network theta/belief estimates, the true theta and the stimuli, one figure each."""
        panels = (("flow_theta_mus", "bayes theta"), ("net_theta_mus", "net theta"),
                  ("flow_belief_mus", "bayes belief"), ("net_belief_mus", "net belief"),
                  ("PGO_mus", "True theta"), ("GO_mus", "stimuli"))
        for attr, title in panels:
            self.plot_general_trajectory(getattr(self, attr), title)
        
    def plot_net_bayes_same_plot(self, SVG = False):
        """Overlay network belief/theta trajectories on joint ('flow_*') and factorized Bayes trajectories.

        Theta panels are clamped to ``ylim = 1``.  A final figure shows the stimuli.
        ``SVG`` is forwarded to every comparison panel.
        """
        for label, source in (("joint bayes", "flow"), ("factorized bayes", "factorized")):
            self.plot_general_trajectory(self.net_belief_mus, f"net vs {label} belief",
                bayes_mu = getattr(self, f"{source}_belief_mus"), net_var = self.net_belief_vars,
                bayes_var = getattr(self, f"{source}_belief_vars"), SVG = SVG)
            self.plot_general_trajectory(self.net_theta_mus, f"net vs {label} theta",
                bayes_mu = getattr(self, f"{source}_theta_mus"), net_var = self.net_theta_vars,
                bayes_var = getattr(self, f"{source}_theta_vars"), SVG = SVG, ylim = 1)
        self.plot_general_trajectory(self.GO_mus, "stimuli")

    def plot_ICLR_NOGO_plot(self, SVG = False):
        """Plot network vs. joint ('flow') and factorized Bayes trajectories for belief and theta.

        One figure per quantity.  For every split the network trajectory is solid and
        both Bayes trajectories use ``self.bayes_style``, each with error bars from the
        matching ``*_vars`` array.  The factorized curve also gets black 'x' markers.

        Args:
            SVG: if true, save each figure via ``self.save_SVG`` (named "ICLR NOGO <quantity>").
        """
        panels = {
            "belief": ([self.net_belief_mus, self.flow_belief_mus, self.factorized_belief_mus],
                       [self.net_belief_vars, self.flow_belief_vars, self.factorized_belief_vars]),
            "theta": ([self.net_theta_mus, self.flow_theta_mus, self.factorized_theta_mus],
                      [self.net_theta_vars, self.flow_theta_vars, self.factorized_theta_vars]),
        }
        for name, (mus, variances) in panels.items():
            self.init_traj_plot()
            for self.split_i in range(self.split_num):
                i = self.split_i
                self.plot_trajectory_split(mus[0][i, :], SE = variances[0][i, :])
                self.plot_trajectory_split(mus[1][i, :], SE = variances[1][i, :], bayes = True)
                self.plot_trajectory_split(mus[2][i, :], SE = variances[2][i, :], bayes = True)
                plt.plot(mus[2][i, :], marker = 'x', c = 'k', linestyle = self.bayes_style, alpha = .6)
            # title / save / show per figure, matching plot_general_trajectory; the SVG flag is now honoured
            plt.title(f"{name} {self.trajectories_title}", fontdict=self.font_dict)
            if SVG:
                self.save_SVG(name = f"ICLR NOGO {name}")
            if self.show: plt.show()

    def plot_behavior_trajectories(self):
        """Plot behavioural read-outs per split: Q difference, lick prob, first-lick prob (via get_survival), predicted value."""
        mus = [self.QDIFF_mus, self.lick_prob_mus, self.get_survival(self.lick_prob_mus), self.value_mus]
        titles = ["Q DIFF", "lick prob", "FIRST lick prob", "pred value"]
        for k, title in enumerate(titles):
            self.plot_general_trajectory(mus[k], title)
            
    def plot_general_trajectory(self, mu, name, bayes_mu = None, net_var = None, bayes_var = None, SVG = False, ylim = None):
        """Plot one quantity's trajectory per split, optionally against a Bayes reference.

        Args:
            mu: per-split trajectories, indexed ``mu[split, step]``.
            name: title prefix and SVG file name.
            bayes_mu: optional per-split reference, drawn in ``self.bayes_style``.
            net_var, bayes_var: optional per-split error-bar sizes for ``mu`` / ``bayes_mu``.
            SVG: if true, save via ``self.save_SVG``.  " (factorized bayes)" is appended
                to the name when a reference is given and ``self.factorize`` is set.
            ylim: optional upper y limit; the y-axis then spans ``[0, ylim]``.
        """
        self.init_traj_plot()
        has_bayes = bayes_mu is not None  # identity test; elementwise `!= None` is fragile on arrays/tensors
        for self.split_i in range(self.split_num):
            self.plot_trajectory_split(mu[self.split_i, :], SE = None if net_var is None else net_var[self.split_i, :])
            if has_bayes:
                self.plot_trajectory_split(bayes_mu[self.split_i, :], SE = None if bayes_var is None else bayes_var[self.split_i, :], bayes = True)

        if ylim is not None:
            plt.ylim([0, ylim])  # honour the requested limit (existing callers pass 1)
        self.make_color_bar()
        plt.title(f"{name} {self.trajectories_title}", fontdict=self.font_dict)
        if SVG:
            if has_bayes and self.factorize:
                name = name + ' (factorized bayes)'
            self.save_SVG(name = f"{name}")
        if self.show: plt.show()

    def plot_trajectory_split(self, mu, SE = None, bayes = False):
        """Draw the current split's (``self.split_i``) trajectory onto the active axes.

        Args:
            mu: 1-D trajectory, one value per step.
            SE: optional error-bar sizes, same length as ``mu``.
            bayes: draw the line in ``self.bayes_style`` instead of solid.

        The points get per-step colours from ``self.act_col``; the line uses the split
        colour and ``self.split_leg`` as its legend label.
        """
        steps = np.arange(len(mu))  # follow mu's length so x always matches y and the colour slice
        col = self.split_cols[self.split_i, :]
        linestyle = self.bayes_style if bayes else '-'
        if SE is not None:
            plt.errorbar(steps, mu, yerr=SE, color = col, alpha = .2, capsize=5)
        plt.scatter(steps, mu, c = self.act_col[self.split_i, :len(mu), :], alpha = .05)
        plt.plot(mu, color = col, linestyle = linestyle, label = self.split_leg[self.split_i])
        plt.xlabel("time",  fontdict=self.font_dict)

    
#################################################################################################
    """ Phase diagram flow plotting """ 
    def plot_2D_trajectory_with_velocity(self, SVG = False, ax=None, cue=None):
        """Plot mean PC trajectories in 2-D with GO/NOGO velocity arrows.

        For each stimulus ("GO", "NOGO"; only ``cue`` if given), every split's
        ``self.PC_prev_mus`` path (rows 0 and 1, from step 1 on) is drawn in its split
        colour.  The stimulus-specific step differences (``self.GO_traj_diffs`` /
        ``self.NOGO_traj_diffs``) are drawn as quiver arrows in ``c_go`` / ``c_nogo``.

        Args:
            SVG: if true, save each figure as PNG and SVG under ``FIGPATH / "appendix figs"``.
            ax: optional axes to draw into.  All requested stimuli then share it and the
                method does not call show.  When omitted, each stimulus gets its own
                figure created with ``enzyme.init_mpl``.
            cue: "GO" or "NOGO" to draw only that stimulus; ``None`` draws both.
        """
        self.run_trajectory(plot=False)
        self.get_trajectory_diffs()
        traj_color = self.Cmap.to_rgba(np.linspace(0, 1, self.max_traj_steps-1))
        from enzyme.colors import c_go, c_nogo

        # Own figure(s) only when no axes was supplied.  Keep init_mpl's pyplot under a
        # separate name: rebinding `plt` here would make it an unbound local for the
        # whole function whenever `ax` is passed in.
        own_figure = ax is None
        if own_figure:
            import enzyme
            from importlib import reload
            try:
                # best effort: pick up edits to enzyme.TEXTWIDTH / init_mpl in interactive sessions
                reload(enzyme)
            except Exception:
                pass
            textwidth = enzyme.TEXTWIDTH
            fig_plt = enzyme.init_mpl(usetex=False)
            fig_plt.close('all')

        for stim_name, stim_data, c in zip(["GO", "NOGO"], [self.GO_traj_diffs, self.NOGO_traj_diffs], [c_go, c_nogo]):
            if cue is not None and stim_name != cue:
                continue

            if own_figure:
                fig, ax = fig_plt.subplots(1, figsize = (textwidth, textwidth/3))
            else:
                fig = ax.figure

            ax.set_box_aspect(1)
            for self.split_i, self.split_curr in enumerate(self.split):
                mu = self.PC_prev_mus[:, self.split_i, 1:]
                d_stim = stim_data[:, self.split_i, 1:]
                act_color = self.act_col[self.split_i, 1:self.max_traj_steps, :]
                split_color = self.split_cols[self.split_i]
                # NOTE(review): ls="none" together with marker="none" makes these artists invisible;
                # presumably meant to draw per-step markers — confirm before changing.
                for i, _ in enumerate(traj_color):
                    ax.plot(mu[0], mu[1], ls="none", c=traj_color[i], marker="none", markeredgecolor="xkcd:dark gray", alpha=0.5, markersize=3)
                    ax.plot(mu[0], mu[1], ls="none", c=act_color[i], marker="none", markeredgecolor="xkcd:dark gray", alpha=.5, markersize=3)

                ax.plot(mu[0], mu[1], c=split_color, alpha=0.5, zorder=-10)
                ax.quiver(mu[0], mu[1], d_stim[0], d_stim[1], color=c, alpha=0.5, angles = 'xy',  scale_units = 'xy', scale = 1, zorder=-5)
            ax.set_xlabel(self.PC_Xlab)
            ax.set_ylabel(self.PC_Ylab)
            # NOTE(review): the limits use `d_stim` from the last split only — confirm intended.
            ax.set_xlim([-.15, self.PC_prev_mus[0].max()+ .1 + d_stim[0].max()])
            ax.set_ylim([self.PC_prev_mus[1].min() -.05 + min(0, d_stim[1].min()), self.PC_prev_mus[1].max()+ .05+d_stim[1].max()])
            title = f"{stim_name} adaptation split by {self.col} for {self.cog_map} trajectory"
            ax.set_title(title)
            if SVG:
                save_plot(fig=fig, name=title, path=FIGPATH / "appendix figs", file_formats=["png", "svg"])
            if own_figure and self.show:
                fig_plt.show()

    def plot_trajectory_with_velocity(self):
        """Plot mean 3-D PC trajectories per split with GO (``c_go``) and NOGO (``c_nogo``) velocity arrows.

        Uses ``self.PC_prev_mus`` from step 1 on.  Points are coloured by time
        (``self.Cmap``) and by ``self.act_col``, paths use the split colour, and the
        camera looks down from ``view_init(90, 0)``.
        """
        from enzyme.colors import c_go, c_nogo  # loop-invariant: import once
        self.run_trajectory(plot=False)
        self.get_trajectory_diffs()
        ax = self.init_3D_plot()
        traj_color = self.Cmap.to_rgba(np.linspace(0, 1, self.max_traj_steps-1))
        marker_area = plt.rcParams['lines.markersize']**2

        for self.split_i, self.split_curr in enumerate(self.split):
            mu = self.PC_prev_mus[:, self.split_i, 1:]
            d_go = self.GO_traj_diffs[:, self.split_i, 1:]
            d_nogo = self.NOGO_traj_diffs[:, self.split_i, 1:]
            act_color = self.act_col[self.split_i, 1:self.max_traj_steps, :]
            split_color = self.split_cols[self.split_i]

            ax.scatter3D(mu[0], mu[1], mu[2], s=marker_area, c=traj_color, alpha=.5)
            ax.scatter3D(mu[0], mu[1], mu[2], s=marker_area, c=act_color, alpha=0.5)
            ax.plot3D(mu[0], mu[1], mu[2], linewidth=8, c=split_color, alpha=0.5)
            ax.quiver(mu[0], mu[1], mu[2], d_go[0], d_go[1], d_go[2], color=c_go, alpha=0.5)
            ax.quiver(mu[0], mu[1], mu[2], d_nogo[0], d_nogo[1], d_nogo[2], color=c_nogo, alpha=0.5)

        # axis labels and camera do not depend on the split: set them once
        ax.set_xlabel(self.PC_Xlab)
        ax.set_ylabel(self.PC_Ylab)
        ax.set_zlabel(self.PC_Zlab)
        ax.view_init(90, 0)
        if self.show: plt.show()

    """ plot trajectory of constant go vs nogo"""
    def plot_constant_GO_NOGO_trajectory(self):
        """Plot planted constant-GO vs constant-NOGO trajectories in 3-D PC space, one figure per view.

        For each view in ``self.angles3D`` and each plant ID (``GOs`` = 1 is the GO
        plant, 0 is drawn as NOGO), the planted trials are re-selected and the
        trajectory recomputed.  Per split it draws:

        - sanity-check velocity arrows from the step differences;
        - a thick translucent path in the stimulus colour (C2 for GO, red for NOGO);
        - a thin path in the split colour;
        - step markers ('H' GO / 'X' NOGO, up to ``scatter_til``);
        - a large green marker on the first plotted point.

        It then draws a green line through every split's first ``PC_mus`` step.  Each
        view's figure is saved via ``self.save_SVG``, named after the angle.
        """
        scatter_til = 100
        GOs = 1 # plant ID for GOs type
        for angle in self.angles3D:
            ax = self.init_3D_plot()
            for plant_ID in [0, 1]:
                self.get_indices(col = 'block', plant_ID = plant_ID, planted = True, align_on = 'onset', flatten = True)
                is_go = plant_ID == GOs
                c = 'C2' if is_go else 'r'
                M = 'H' if is_go else 'X'
                self.run_trajectory(plot = False)
                self.get_trajectory_diffs()

                for self.split_i in range(self.split_num):
                    split_color = self.split_cols[self.split_i]
                    X, Y, Z = [self.PC_prev_mus[dim, self.split_i, 1:] for dim in range(self.PC_dim)]

                    # sanity check of the step differences: one arrow per step (previously drawn twice)
                    diffs = self.GO_traj_diffs if is_go else self.NOGO_traj_diffs
                    dX, dY, dZ = [diffs[dim, self.split_i, 1:] for dim in range(self.PC_dim)]
                    ax.quiver(X, Y, Z, dX, dY, dZ)

                    ax.plot3D(X, Y, Z, c = c, linewidth = 15, alpha = .25)
                    ax.plot3D(X, Y, Z, c = split_color, alpha = 1, linewidth = 2)
                    ax.scatter3D(X[1:scatter_til], Y[1:scatter_til], Z[1:scatter_til], marker=M, color = split_color, s = 100, alpha = .5)
                    ax.scatter3D(X[0], Y[0], Z[0], s = 500, color = 'g', alpha = .5)

                X, Y, Z = [self.PC_mus[dim, :, 0] for dim in range(self.PC_dim)]
                ax.plot3D(X, Y, Z, linewidth = 5, color = 'g', alpha = 1)
                ax.set_xlabel('PC 1'); ax.set_ylabel('PC 2'); ax.set_zlabel('PC 3')
                ax.view_init(angle[0], angle[1])
            self.save_SVG(name = str(angle))
            if self.show: plt.show()
            
    def plot_stim_specific_flow_field(self, name = "NETWORK", til_action = True, from_action = False, SVG = False):
        """Draw one stream-plot panel per stimulus (NOGO on the left, GO on the right).

        Velocities come from ``get_phase_space`` and each panel is rendered by
        ``plot_flow_field``. The title records whether the flow was restricted to
        pre-action and/or post-action steps.
        """
        self.get_phase_space(til_action = til_action, from_action = from_action)
        _, self.ax = plt.subplots(1, 2, figsize = (self.fig_W, self.fig_H), tight_layout = True)
        pre_tag = '(pre action)' if til_action else ''
        post_tag = '(post action)' if from_action else ''
        stimuli = (("NOGO", "r"), ("GO", "C2"))
        for stim_i, (stim_name, colour) in enumerate(stimuli):
            # plot_flow_field reads the active panel from self.stim_i
            self.stim_i = stim_i
            self.plot_flow_field(U = self.dPC1[stim_i].T, V = self.dPC2[stim_i].T, c = colour)
            self.ax[stim_i].set_title(f"{name} {stim_name} flow\n {pre_tag} {post_tag}")
        if SVG:
            self.save_SVG(f"{name} flow")
        if self.show:
            plt.show()
        
        # X, Y = np.meshgrid(self.uniques[0], self.uniques[1])
        # U, V = self.dPC1[0].T + self.dPC1[1].T, self.dPC2[0].T + self.dPC2[1].T
        # magnitude = np.sqrt(U**2 + V**2)
        # plt.streamplot(X,Y, U,V,  linewidth= .5 + 3*magnitude/magnitude.max())
        # self.set_cog_map_labels(plt)
        # plt.title("total flow"); if self.show: plt.show()
        # self.plot_flow_policy()
        # self.plot_flow_delta_amps(stim_i)
        
    def plot_flow_field(self, U, V, c = 'C1'):
        """Stream-plot the velocity field (U, V) on panel ``self.ax[self.stim_i]``.

        Args:
            U, V: velocity components on the grid ``np.meshgrid(self.uniques[0], self.uniques[1])``.
            c: streamline colour.

        Line width is proportional to local speed, with a maximum width of 3. A field
        with no positive finite speed is drawn with zero-width lines. Previously it
        was divided by a zero (or NaN) maximum, which produced NaN widths.
        """
        X, Y = np.meshgrid(self.uniques[0], self.uniques[1])
        magnitude = np.sqrt(U**2 + V**2)
        finite = np.isfinite(magnitude)
        peak = magnitude[finite].max() if finite.any() else 0.0
        widths = 3 * magnitude / peak if peak > 0 else np.zeros_like(magnitude)
        ax = self.ax[self.stim_i]
        ax.streamplot(X, Y, U, V, color = c, linewidth = widths, broken_streamlines = False)
        self.set_cog_map_labels(ax, subplot = True)
    
    def plot_flow_policy(self):
        """Plot the logistic curve 1 / (1 + exp(-2 q)) over the grid ``self.uniques[0]``."""
        q_grid = self.uniques[0]
        lick_prob = 1 / (1 + np.exp(-2 * q_grid))
        plt.plot(q_grid, lick_prob)
        plt.title("Policy Stochasticity")
        plt.xlabel(f"{self.PC_Xlab}")
        plt.ylabel("action probability")
        if self.show:
            plt.show()
    
    def plot_flow_delta_amps(self, stim_i):
        """Plot mean absolute phase-space velocities restricted to visited grid cells.

        Indices come from ``np.where(self.PC_count[stim_i] > 0)``. Its first index
        array selects along axis 1 of ``self.dPC1`` and its second along axis 2 of
        ``self.dPC2``. Two figures are drawn: |dPC2| against ``self.uniques[0]`` and
        |dPC1| against ``self.uniques[1]``.
        """
        visited = np.array(np.where(self.PC_count[stim_i] > 0))
        dPC1_amp = np.abs(self.dPC1[:, visited[0], :]).mean(-2).T
        dPC2_amp = np.abs(self.dPC2[:, :, visited[1]]).mean(-1).T
        legend = ["Response to NOGO", "Response to GO"]

        plt.plot(self.uniques[0], dPC2_amp)
        plt.title("amplitude of avg THETA difference")
        plt.legend(legend)
        plt.xlabel(f"decoded {self.PC_Xlab}")
        plt.ylabel(f"DELTA {self.PC_Ylab}")
        if self.show:
            plt.show()

        plt.plot(self.uniques[1], dPC1_amp)
        plt.xlabel(f"decoded {self.PC_Ylab}")
        plt.ylabel(f"DELTA {self.PC_Xlab}")
        plt.title("amplitude of avg LICK VALUE difference")
        plt.legend(legend)
        if self.show:
            plt.show()
     
    def plot_quadrant_space(self, SVG = False):
        """Quiver-plot the coarse 3x3 flow quadrants for each stimulus and agent.

        The code indexes ``self.flow_quadrants`` as [agent, stimulus, belief bin,
        theta bin, component], with component 0 drawn along the belief axis and
        component 1 along the theta axis. NaNs are replaced by 0 in place. Each
        component is divided by its own standard deviation. A constant component
        (std 0) is drawn as zero-length arrows instead of NaN/inf arrows. Stimulus
        0 is red and stimulus 1 green. Agents 2 and 1 are drawn as thin hatched
        arrows (black / white edges) around the wider agent-0 arrow.
        """
        self.flow_quadrants = np.nan_to_num(self.flow_quadrants, nan = 0)
        belief_axis = np.array([-9, -2, 5])
        theta_axis = np.array([-6, 0, 6])
        X, Y = np.meshgrid(belief_axis, theta_axis)

        def _unit_scaled(component):
            # Normalise by spread; guard the std == 0 case (0/0 -> NaN arrows).
            spread = component.std()
            return component / spread if spread > 0 else np.zeros_like(component)

        plt.subplots(1, 1, figsize = (10, 10), tight_layout = True)
        # (agent index, edge colour, hatch, shaft width) -- order sets overlap.
        arrow_styles = [(2, 'k', '|||||', .0085), (0, None, None, .02), (1, 'w', '|||||', .0085)]
        for stim_i, stim_col in zip([0, 1], ['r', 'g']):
            for agent_i, edge, hatch, width in arrow_styles:
                U = _unit_scaled(self.flow_quadrants[agent_i, stim_i, :, :, 0]).T
                V = _unit_scaled(self.flow_quadrants[agent_i, stim_i, :, :, 1]).T
                plt.quiver(X, Y, U, V, color = stim_col, angles = 'xy', scale_units = 'xy', scale = 1, edgecolor = edge, hatch = hatch, width = width)
        plt.xticks(ticks = belief_axis, labels = ["low belief", "med belief", "high belief"])
        plt.yticks(ticks = theta_axis, labels = ["low theta", "med theta", "high theta"])
        plt.xlim([-11, 9])
        plt.ylim([-11, 8])
        if SVG:
            self.save_SVG("quadrant flow")

    

#################################################################################################
    """ PC clustering plotting """ 
    
    def plot_2D_stim_clusters(self):
        """Scatter PC locations against PC velocities, coloured by stimulus.

        For each PC row (indices wrap modulo ``self.PC_dim``) three panels are drawn:
        the (PC, PC+1) locations; the previous PC location vs the PC+2 velocity; the
        previous PC+1 location vs the PC+2 velocity. GO (halo 'C2') and NOGO (halo
        'C3') steps are each resampled with replacement to 5000 points. Marker
        centres are coloured by ``self.PGO_flat`` through ``self.Cmap``.
        """
        self.get_indices(planted = None, align_on = 'onset', flatten = True)
        self.run_trajectory(plot = False)
        self.get_GO_NOGO_inds()

        _, ax = plt.subplots(3, 3, figsize = (20, 17), tight_layout = True)
        col_from = self.PGO_flat         # self.pre_act  # FOR COLOR BY PRE-ACT POST ACT
        D = self.PC_dim

        for PC in range(D):
            nxt, far = (PC + 1) % D, (PC + 2) % D
            for stim_col, stim_inds in (('C2', self.GO_inds), ('C3', self.NOGO_inds)):
                sample = np.random.choice(stim_inds, 5000)
                inner_col = self.Cmap.to_rgba([col_from[sample]]).squeeze(0)
                velocity = self.PC_diff[far, sample]
                panels = (
                    (self.PC_flat[PC, sample], self.PC_flat[nxt, sample]),
                    (self.PC_prev[PC, sample], velocity),
                    (self.PC_prev[nxt, sample], velocity),
                )
                for col, (xs, ys) in enumerate(panels):
                    ax[PC, col].scatter(xs, ys, s = 100, alpha = .5, color = stim_col)
                    ax[PC, col].scatter(xs, ys, s = 10, alpha = 1, c = inner_col)

            a, b, z = PC % D + 1, nxt + 1, far + 1
            ax[PC, 0].set_title(f"PC{a} and PC{b} locations")
            ax[PC, 0].set_xlabel(f"location in PC{a}")
            ax[PC, 0].set_ylabel(f"location in PC{b}")
            ax[PC, 1].set_title(f"effect of PC{a} location on PC{z} velocity")
            ax[PC, 1].set_xlabel(f"location in PC{a}")
            ax[PC, 1].set_ylabel(f"velocity of PC{z}")
            ax[PC, 2].set_title(f"effect of PC{b} location on PC{z} velocity")
            ax[PC, 2].set_xlabel(f"location in PC{b}")
            ax[PC, 2].set_ylabel(f"velocity of PC{z}")
        if self.show:
            plt.show()

    def plot_specific_2D_diffs(self):
        """Draw three hand-picked previous-location -> velocity scatter plots.

        1. previous PC3 vs PC2 velocity on GO steps (centres coloured by ``self.pre_act``)
        2. previous PC2 vs PC1 velocity on GO steps (centres coloured by ``self.PGO_flat``; use til action)
        3. previous PC1 vs PC2 velocity on NOGO steps (centres coloured by ``self.PGO_flat``; use from action)
        GO and NOGO indices are each resampled with replacement to 5000 points.
        """
        self.get_indices(planted = None, align_on = 'onset', flatten = True)
        self.run_trajectory(plot = False)
        self.get_GO_NOGO_inds()

        G_inds = np.random.choice(self.GO_inds, 5000)
        N_inds = np.random.choice(self.NOGO_inds, 5000)

        # (location PC, velocity PC, indices, halo colour, centre-colour attribute, title, x label, y limits)
        panels = [
            (2, 1, G_inds, 'C2', 'pre_act', "Effect on PC2 velocities", "PC3", [-.5, .5]),
            (1, 0, G_inds, 'C2', 'PGO_flat', "Effect on PC1 velocities", "PC2", None),
            (0, 1, N_inds, 'C3', 'PGO_flat', "Effect on PC2 velocities", "PC1", None),
        ]
        for loc_pc, vel_pc, idx, halo, colour_attr, title, xlabel, ylim in panels:
            _, ax = plt.subplots(1, 1, figsize = (15, 15), tight_layout = True)
            xs = self.PC_prev[loc_pc, idx]
            ys = self.PC_diff[vel_pc, idx]
            ax.scatter(xs, ys, s = 1000, c = halo, alpha = .2)
            centre = self.Cmap.to_rgba([getattr(self, colour_attr)[idx]]).squeeze(0)
            ax.scatter(xs, ys, s = 100, c = centre, alpha = .2)
            if ylim is not None:
                ax.set_ylim(ylim)
            ax.set_title(title)
            ax.set_xlabel(xlabel)
            if self.show:
                plt.show()
        
        
###############################################################################################################################################

    def get_episode_PC_evolution(self, round_dim = 2):
        """Tabulate, per episode, mean PC velocity as a function of previous PC location.

        The projection is set up via ``get_indices(planted=False, eps_init=1500)`` and
        ``run_PCA(override_indices=True)``. All of ``self.net_output`` is then
        projected to build ``self.uniques``, the rounded (``round_dim`` decimals)
        location grid per PC. For every episode, each step's previous PC location is
        rounded the same way. The step velocity ``PC_flat - PC_prev`` is averaged
        over the GO / NOGO steps sharing that location.

        Result: ``self.PC_yax[episode, stim (0 = GO, 1 = NOGO), PC, PC_diff]`` is an
        array over ``self.uniques[PC]``, 0 where no step matched. Ends by calling
        ``plot_episode_PC_evolution``.
        """
        self.get_indices(planted = False, eps_init = 1500, flatten = False)
        self.run_PCA(override_indices=True)
        self.PC_yax = np.empty((self.episodes, 2, self.PC_dim, self.PC_dim), dtype = object)                         # dim: [episode, GO/NOGO, PC, diff][xax]
        self.PC_all = self.proj.transform(self.flatten_trajectory(self.net_output).T).T
        self.uniques = np.array([np.unique(np.round(PC, round_dim)) for PC in self.PC_all], dtype = object)

        for e in range(self.episodes):
            e_inds = np.where(self.eps == e)
            self.output_flat = self.flatten_trajectory(self.net_output[e_inds])
            self.PC_flat = self.proj.transform(self.output_flat.T).T
            self.PC_prev = np.concatenate((np.zeros((self.PC_dim, 1)), self.PC_flat[:, :-1]), -1)
            # Velocities must be taken from THIS episode. The indices below are
            # episode-local, so a PC_diff left over from other data would be
            # indexed at unrelated time steps.
            self.PC_diff = self.PC_flat - self.PC_prev
            PC_round = np.round(self.PC_prev, round_dim)
            self.get_GO_NOGO_inds()

            for stim_i, stim_inds in enumerate([self.GO_inds, self.NOGO_inds]):
                for PC in range(self.PC_dim):
                    curves = [np.zeros(len(self.uniques[PC])) for _ in range(self.PC_dim)]
                    for u_i, u in enumerate(self.uniques[PC]):
                        inds = np.intersect1d(np.where(PC_round[PC, :] == u)[0], stim_inds)
                        if len(inds) == 0:
                            continue  # no step at this location: leave 0
                        for d in range(self.PC_dim):
                            curves[d][u_i] = self.PC_diff[d, inds].mean(-1)
                    for d in range(self.PC_dim):
                        self.PC_yax[e, stim_i, PC, d] = curves[d]
        self.plot_episode_PC_evolution()

    def plot_episode_PC_evolution(self):
        """Render ``self.PC_yax`` as GO / NOGO response surfaces across episodes.

        For every (location PC, velocity PC) pair, one 3D figure is drawn:
        - x: the location grid ``self.uniques[PC]`` divided by its max absolute value;
        - y: ``np.arange(len(PC_Z))``, the row index of the smoothed table;
        - z: the smoothed response, normalised per row by the largest absolute
          GO/NOGO value of that row.
        GO is drawn in 'C2', NOGO in 'r'.
        """
        PC_Z = self.smooth(self.PC_yax, 100, 0)
        n_rows = len(PC_Z)
        row_axis = np.arange(n_rows)

        for PC in range(self.PC_dim):
            grid = self.uniques[PC]
            norm_X = np.max(np.abs(grid))
            X, Y = np.meshgrid(grid, row_axis)
            for PC_diff in range(self.PC_dim):
                row_peaks = [np.max(np.abs(np.stack(PC_Z[e, :, PC, PC_diff]))) for e in range(n_rows)]
                norm_Z = np.array(row_peaks)[:, None]
                _, ax = plt.subplots(1, 1, figsize = (19,9), subplot_kw=dict(projection='3d'), tight_layout = True)
                for stim_i, colour in enumerate(["C2", "r"]):
                    Z = np.vstack(PC_Z[:, stim_i, PC, PC_diff])
                    Z = self.smooth(Z, 10, axis = 1)[:, 5:-4]
                    ax.plot_surface(X/norm_X, Y, Z/norm_Z, color = colour, alpha = .65)
                ax.set_xlabel(f"PC{PC+1}")
                ax.set_zlabel(f"Delta PC{PC_diff+1}")
                ax.set_ylabel("time")
                ax.set_title(f"effect of PC{PC+1} on PC{PC_diff+1} response")
                ax.set_box_aspect([1,2, 1])
                if self.show:
                    plt.show()
            
#################################################################################################

    def plot_corrs(self):
        """Plot correlations between the task inputs and the principal components.

        Panels 0-2 show the cross-correlation of each input ("GO", "LICK", "REWARD")
        with the three PCs (solid), plus the PC auto-correlations (dashed). Tick
        labels run from -points/2 to +points/2. Panel 3 plots ``self.inp_corrs``
        per input.

        A three-colour cycle is applied only while this figure is built, via
        ``mpl.rc_context``. The caller's ``axes.prop_cycle`` is restored afterwards
        rather than overwritten with a hard-coded default cycle.
        """
        inps = ["GO", "LICK", "REWARD"]
        points = int(self.inp_cross.shape[-1])
        cross_corr_tick_labels = np.linspace(-points, points, 5)/2
        cross_corr_ticks = np.arange(len(cross_corr_tick_labels))
        xax = np.linspace(0, cross_corr_ticks[-1], points)
        l1 = [f" PC{PC}" for PC in [1,2,3]]
        l2 = [f" PC{PC} auto correlation" for PC in [1,2,3]]
        three_colours = cycler(color=['tab:blue', 'tab:orange', 'tab:green'])

        # The colour cycle is read when axes are created and lines are added, so
        # both must happen inside the context.
        with mpl.rc_context({'axes.prop_cycle': three_colours}):
            fig, ax = plt.subplots(4, 1, figsize = (10, 20), tight_layout = True)
            ax[3].plot(self.inp_corrs, '-o', label = l1)

            for i, name in enumerate(inps):
                ax[i].plot(xax, self.inp_cross[i, :,:].T, label = l1)
                ax[i].set_xticks(cross_corr_ticks)
                ax[i].set_xticklabels(cross_corr_tick_labels)
                ax[i].set_title(f"Cross-correlation betwen {name} & PCs")
                ax[i].plot(xax, self.auto_corrs.T, '--', label = l2, alpha = .5)
                ax[i].legend(loc = "upper right")

            ax[3].set_xticks(np.arange(len(inps)))
            ax[3].set_xticklabels(inps)
            ax[3].set_xlabel("input")
            ax[3].legend(loc = "upper right")
            ax[3].set_title("Correlation coeffs")
        if self.show: plt.show()
        
    def plot_net_angles(self, power = 3, sort = False):
        """Plot how each PC projects onto the action read-out and onto every LSTM gate.

        The first figure plots ``self.act_proj`` per action. Then, for each gate in
        ``self.gate_names``, ``self.gate_proj[:, gate_index, :]`` raised to ``power``
        is plotted per neuron. When ``sort == True`` it is sorted along the neuron axis.
        """
        legend = ["PC1", "PC2", "PC3"]
        plt.plot(self.act_proj.T, '-o')
        plt.legend(legend)
        plt.xlabel("Action")
        plt.title("contribution to action")
        if self.show:
            plt.show()
        for gate_idx, gate in enumerate(self.gate_names):
            contribution = self.gate_proj[:, gate_idx, :] ** power
            if sort == True:
                curves = np.sort(contribution, -1)
            else:
                curves = contribution
            plt.plot(curves.T, '-o')
            plt.xlabel(f"Neuron (Sorted = {sort})")
            plt.legend(legend)
            plt.title(f"Contribution to {gate} to the power {power}")
            if self.show:
                plt.show()
            
    def plot_gate_eigen(self):
        """Compare eigenvalue spectra of the LSTM gate weights with a random baseline.

        Top panel: eigenvalues in the complex plane. Bottom panel: squared magnitudes
        in descending order. The baseline is a ``hid_dim`` x ``hid_dim`` standard
        Gaussian matrix scaled by 1/sqrt(hid_dim).
        """
        W_rand = np.random.randn(self.hid_dim, self.hid_dim)/np.sqrt(self.hid_dim)
        # NOTE(review): matrices are ordered i, o, c, f; confirm self.gate_names uses
        # the same order, otherwise the legend mislabels the spectra.
        Ws = [self.Wi_cpu, self.Wo_cpu, self.Wc_cpu, self.Wf_cpu, W_rand]
        eig_names = list(self.gate_names) + ["Random guassian W / sqrt(N)"]
        _, ax = plt.subplots(2, 1, figsize = (8, 10), tight_layout = True)
        for w in Ws:
            # Only eigenvalues are used. eigvals skips the eigenvector computation
            # that eig(..., right=True) performed and then discarded.
            vals = lin.eigvals(w)
            vals = vals[np.abs(vals).argsort()[::-1]]
            real_vals, imag_vals = np.real(vals), np.imag(vals)
            ax[0].scatter(real_vals, imag_vals, s = 85)
            ax[1].plot(real_vals**2 + imag_vals**2, '-o')
        ax[0].set_title("eigenvalue real vs imaginary");      ax[1].set_title("Power spectrum")
        ax[0].set_xlabel("real");                             ax[1].set_xlabel("eigenvalue #")
        ax[0].legend(eig_names);                              ax[1].legend(eig_names)
        ax[0].set_ylabel("imaginary")
        if self.show: plt.show()

    def plot_weight_variability(self):
        """Heat-map the split, time, and split x time variance of the modulated cellgate weights."""
        def _heat(matrix, title):
            sns.heatmap(matrix)
            plt.title(title)
            if self.show:
                plt.show()

        _heat(self.split_var, f"variance due to {self.col} (total variance: {self.split_var.sum():.3f})")
        _heat(self.time_var, f"variance across time (total variance: {self.time_var.sum():.3f})")
        _heat(self.split_x_time_var, f"ELEMENT-WISE: (normed {self.col} var) x (normed time var) \n(norm of result: {lin.norm(self.split_x_time_var):.3f})")
                
    def plot_eigen_variability(self):
        """Plot how the modulated-cellgate eigenvalues vary through time, per split.

        Runs once for ``self.mod_eig_vals_R`` ('real') and once for
        ``self.mod_eig_vals_I`` ('complex'). Each run draws:
        - ``val.var(-1)`` as a scatter through ``make_standard_plot``;
        - one panel per split with the absolute eigenvalue traces.
        """
        for val, name in zip([self.mod_eig_vals_R, self.mod_eig_vals_I], ['real', 'complex']):
            self.make_standard_plot(ys = val.var(-1), cols = self.split_cols, \
                alphs = [.5] * self.split_num, leg =  self.split_leg, xlab = "time", plots = ['scatter'] * self.split_num, \
                title = f"variance between (modulated cellgate) {name} eigenvalues through time", traj = True)

            # The panels plot |val|, so the upper limit must be max|val|. The old
            # max(val**2) is below max|val| whenever max|val| < 1 and clipped traces.
            y_top = np.max(np.abs(val))
            fig, ax = plt.subplots(1, max(2, self.split_num), figsize = (20,5), tight_layout = True)
            for split in range(self.split_num):
                ax[split].plot(np.abs(val[split]), linewidth = 3, alpha = .25)
                ax[split].set_title(f"{self.col} {split} (modulated cellgate) {name} abs(eigenvalues) through time")
                ax[split].set_xlabel("time")
                ax[split].set_ylim([0, y_top])
            if self.show: plt.show()
        
            # self.make_standard_plot(ys = (val**2).mean(-1), cols = self.split_cols, \
            #     alphs = [.5] * self.split_num, leg =  self.split_leg, xlab = "time", plots = ['scatter'] * self.split_num, \
            #     title = f"(modulated cellgate) {name} mean(eigenvalues^2) through time", traj = True)

            # self.make_standard_plot(ys = val.var(1), cols = self.split_cols, \
            #     alphs = [.5] * self.split_num, leg =  self.split_leg, xlab = "eigenvalue", plots = ['scatter'] * self.split_num, \
            #     title = f"variance of each (modulated cellgate) {name} eigenvalue across time")
            
###############################################################################################################################################

    def layer_theta_traj(self, est, dist = None):
        """Cut a flat series into consecutive windows of ``self.theta_traj_N`` steps.

        The number of windows is ``reps = floor(len(self.PGO_flat) / self.theta_traj_N)``.
        Any trailing partial window is ignored.

        Args:
            est: 1D series with at least ``reps * theta_traj_N`` entries.
            dist: optional 2D array (rows, time). Every row is windowed the same way.

        Returns:
            Tuple of
            - windows of ``est``, shape (reps, theta_traj_N), as float;
            - ``dist`` averaged over windows, shape (rows, theta_traj_N), or 0.0 if ``dist`` is None;
            - per-step mean over windows;
            - per-step variance over windows.

        Raises:
            ValueError: if fewer than ``theta_traj_N`` samples are available.
                Previously this divided by zero or returned NaN.
        """
        N = self.theta_traj_N
        reps = int(np.floor(self.PGO_flat.shape[0] / N))
        print(f"{reps} reps")
        if reps < 1:
            raise ValueError(f"need at least theta_traj_N = {N} samples, got {self.PGO_flat.shape[0]}")
        span = reps * N
        theta_est = np.asarray(est[:span], dtype = float).reshape(reps, N)
        if dist is None:
            theta_dist = 0.0
        else:
            windows = np.asarray(dist[:, :span], dtype = float).reshape(dist.shape[0], reps, N)
            theta_dist = windows.sum(1) / reps
        return theta_est, theta_dist, theta_est.mean(0), theta_est.var(0)

    def get_p_TH_theta(self):
        """Average ``self.theta_dist_flat`` columns near each theta of ``self.bayes_range``.

        For each theta, the selected columns are those where
        |``self.flow_theta_flat`` - theta| < 1e-3. Returns a list with one row-wise
        mean vector per theta.
        """
        averages = []
        for theta in self.bayes_range:
            near_theta = np.abs(self.flow_theta_flat - theta) < 1e-3
            averages.append(self.theta_dist_flat[:, near_theta].mean(1))
        return averages

    def plot_fixed_theta_traj(self, s, e, SVG = False):
        """Compare theta estimates over window [s, e) of a fixed-theta trajectory.

        Series are windowed with ``layer_theta_traj``. The background is a colour
        mesh of the window-averaged ``self.theta_dist_flat`` over ``self.bayes_range``.
        On top are mean +/- variance bands for the joint Bayes, factorized Bayes and
        network estimates, plus ``self.PGO_flat`` (labelled "theta").

        The figure is only shown when ``self.show`` is truthy. Previously
        ``plt.show()`` was called unconditionally.
        """
        _, bayes_theta_dist, bayes_mu, bayes_var = self.layer_theta_traj(self.flow_theta_flat, self.theta_dist_flat)
        _, _, factorized_bayes_mu, factorized_bayes_var = self.layer_theta_traj(self.factorized_theta_flat)
        _, _, net_mu, net_var = self.layer_theta_traj(self.net_theta_flat)
        # The window (MF) and raw-input estimates were computed here but never drawn; dropped.
        xax = np.arange(self.theta_traj_N)[:e-s]

        # Without self.show, drawing goes to the current figure (plt.* creates one if needed).
        if self.show: plt.subplots(figsize = (25,10))
        norm = cm_cols.Normalize(vmin=0)
        plt.pcolormesh(xax, self.bayes_range, bayes_theta_dist[:, s:e], norm=norm, cmap=plt.get_cmap("inferno"), shading = 'nearest')

        self.theta_fill(xax, bayes_mu, bayes_var, s, e, color = 'C2', label = "joint bayes", linestyle = '--')
        self.theta_fill(xax, factorized_bayes_mu, factorized_bayes_var, s, e, color = 'C3', label = "factorized bayes", linestyle = '--')
        self.theta_fill(xax, net_mu, net_var, s, e, color = 'C0', label = "network")
        plt.plot(xax,self.PGO_flat[:self.theta_traj_N][s:e], label = "theta", linewidth = 5, c = 'C1')
        title = "fixed theta trajectory" + (" (factorized)" if self.factorize else "")
        plt.title(title); plt.legend()
        if SVG:
            self.save_SVG(title)
        if self.show: plt.show()
        
        # fig, ax = plt.subplots(figsize = (25,10))
        # # s = 0
        # e = 200
        # plt.plot(self.PSAFE_flat[s:e])
        # plt.plot(1-np.exp(-self.flow_belief_flat[s:e].astype(float)))
        # plt.plot(1-np.exp(-self.net_belief_flat[s:e].astype(float)))
        # plt.legend(["State", "bayesian belief", "network belief"])
        # if self.show: plt.show()

    def theta_fill(self, xax, mu, var, s, e, color, label, linestyle = '-', alpha = 1):
        """Draw ``mu[s:e]`` with a shaded band of +/- ``var[s:e]``.

        Returns the list of lines produced by ``plt.plot``.
        """
        mu_win = mu[s:e]
        var_win = var[s:e]
        plt.fill_between(xax, (mu_win - var_win), (mu_win + var_win), alpha = .5, color = color)
        return plt.plot(xax, mu_win, alpha = alpha, color = color, label = label, linestyle = linestyle)

    def plot_heat(self, fig_height, fig_width):
        """Plot linear-regression coefficient heat maps next to their score curves.

        One row per name in ``self.DV_names``:
        - left: ``self.heats[DV_name]``, with a vertical line at every multiple of
          ``hid_dim`` (one per layer);
        - right: ``self.score[DV_name]`` (titled "MSE").
        """
        # squeeze=False keeps `ax` 2D even for a single DV, so ax[x, col] indexing works.
        fig, ax = plt.subplots(len(self.DV_names), 2, figsize = (fig_width, fig_height), tight_layout = True, squeeze = False)
        # layer_names is concatenated into a title; accept a sequence as well as a str.
        layers = self.layer_names if isinstance(self.layer_names, str) else ", ".join(map(str, self.layer_names))
        for x, DV_name in enumerate(self.DV_names):
            heat = self.heats[DV_name]
            sns.heatmap(heat, ax = ax[x, 0])
            for i in range(self.num_layers):
                ax[x,0].axvline(i*self.hid_dim)
            ax[x, 0].set_title(DV_name + " (DV) \npredicted by " + layers)
            ax[x, 0].set_ylabel(f"{DV_name} memory length (DV)")
            ax[x, 0].set_xlabel("neuron (IV)")
            ax[x, 1].plot(self.score[DV_name])
            ax[x, 1].set_title("MSE")
            ax[x, 1].set_xlabel(f"{DV_name} (DV) mem")
        fig.suptitle("REG abs(coefs)")
        if self.show: plt.show()

    def plot_traj(self, fig_height, fig_width):
        """Plot ground-truth and regression-reconstructed trajectories per dependent variable.

        One row per name in ``self.DV_names``:
        - left: ``self.DV_mus[DV_name][self.mem]`` for every split;
        - right: ``self.recons[DV_name]`` for every split.
        Split colours and legend entries come from ``self.split_cols`` and ``self.split_leg``.
        """
        # squeeze=False keeps `ax` 2D even for a single DV, so ax[x, col] indexing works.
        fig, ax = plt.subplots(len(self.DV_names), 2, figsize = (fig_width, fig_height), tight_layout = True, squeeze = False)
        for x, DV_name in enumerate(self.DV_names):
            for split_i in range(self.split_num):
                col = self.split_cols[split_i, :]
                leg = self.split_leg[split_i]
                ax[x, 0].plot(self.DV_mus[DV_name][self.mem, split_i, :], color = col, label = leg)
                ax[x, 1].plot(self.recons[DV_name][split_i, :], color = col, label = leg)
            ax[x, 0].set_title(DV_name + " (DV) \n GROUND TRUTH")
            ax[x, 0].set_xlabel("time")
            ax[x, 1].set_title(DV_name + " (DV) \npredicted by neuron (IV)")
            ax[x, 1].set_xlabel("time")
        # Legends only on the bottom row; every row uses the same split colours.
        ax[-1, 0].legend()
        ax[-1, 1].legend()
        fig.suptitle(f"Bayesian parameter ground truth and reconstruction for mem {self.mem}\n")
        if self.show: plt.show()
        
    def plot_bayes_vs_recon(self, mem):
        """Draw the coefficient heat maps, the reconstructed trajectories and the mean comparison for ``mem``."""
        figure_size = dict(fig_height = 15, fig_width = 10)
        self.plot_heat(**figure_size)
        self.plot_traj(**figure_size)
        self.plot_bayes_net_mean_corr(mem)
        
    def plot_bayes_net_mean_corr(self, mem):
        """Compare per-split means of ground truth, model-free, Bayesian and network PGO.

        Draws:
        - the identity line over ``self.split`` (ground truth);
        - ``self.GO_mus.mean(-1)`` (model-free);
        - for each split: the Bayesian mean at memory ``mem`` (large C1 dots), the
          network reconstruction mean (large C0 stars), and the Bayesian means at
          memories 1 .. max_mem-1 (C4, increasing alpha).
        Sets ``self.split_i`` / ``self.split_curr`` as it iterates.
        """
        _, ax = plt.subplots(1,1, figsize = (15, 15), tight_layout = True)
        ax.plot(self.split, self.split, '-o', c = "C2")
        ax.plot(self.split, self.GO_mus.mean(-1), '-o', c = "C3")
        for self.split_i, self.split_curr in enumerate(self.split):
            bayes_here = self.DV_mus['BAYES PGO'][mem][self.split_i].mean()
            net_here = self.recons['TRUE PGO'][self.split_i].mean()
            ax.scatter(x = self.split_curr, y = bayes_here, c = 'C1', s = 1500)
            ax.scatter(x = self.split_curr, y = net_here, c = 'C0', s = 1500, marker = '*')
            for m in range(1, self.max_mem):
                shorter = self.DV_mus['BAYES PGO'][m][self.split_i].mean()
                ax.scatter(x = self.split_curr, y = shorter, c = 'C4', alpha = .4 + .4*(m - 1) / self.max_mem)

        plt.legend(["ground truth", "model-free", "Bayesian estimation", "ANN reconstruction"], fontsize = 30, loc ='upper left')
        if self.show:
            plt.show()
        
        
    def plot_mem_PGO_corr(self):
        """Correlate the network PGO read-out with Bayesian PGO at every memory length.

        For each split, the 'TRUE PGO' regressor's prediction on the split's steps is
        correlated (``self.correlate``) with ``self.bayes_PGO_flat``. The result is
        stored row-wise in ``self.mem_PGO_corr`` (PGO_N x max_mem) and heat-mapped.
        Then the correlation-weighted average of ``self.mem_range`` is plotted.
        """
        self.align_on = 'onset'
        self.mem_PGO_corr = np.zeros((self.PGO_N, self.max_mem))
        for self.split_i, self.split_curr in enumerate(self.split):
            self.get_split_inds()
            step_cols = self.full_split_inds_flat
            net_pgo = self.regressor['TRUE PGO'].predict(self.IV[:, step_cols].T)[:, None, 0].T
            bayes_pgo = self.bayes_PGO_flat[:, step_cols]
            self.mem_PGO_corr[self.split_i] = self.correlate(net_pgo, bayes_pgo).squeeze()

        sns.heatmap(self.mem_PGO_corr.T)
        plt.xlabel("PGO")
        plt.title("network - bayes mem correlation")
        if self.show:
            plt.show()

        weighted_mem = (self.mem_range*self.mem_PGO_corr).sum(-1)/self.mem_PGO_corr.sum(-1)
        plt.plot(weighted_mem)
        plt.xlabel("PGO")
        plt.title("network - bayes mem correlation")
        if self.show:
            plt.show()
        
    def plot_bump_reconstruction(self, end, mem, plant_ID = None):
        """Overlay network and Bayesian PGO reconstructions on planted trials.

        Trials are selected with ``get_indices(planted=True, plant_ID=plant_ID,
        eps_init=self.held_out, ...)``. Reconstructions are computed for memory
        ``mem``. The first ``end`` steps of every split are plotted: network
        reconstruction solid, Bayesian estimate dashed.
        """
        self.get_indices(planted = True, plant_ID = plant_ID, eps_init = self.held_out, align_on = 'onset', flatten = True)
        self.get_reconstruction(mem = mem)

        self.xax = np.arange(0, end, 1/self.temp_resolution)
        recons = self.recons['TRUE PGO'], self.DV_mus['BAYES PGO'][mem]
        # `fig` used to be bound only when self.show was truthy, so add_subplot
        # raised NameError otherwise. Fall back to the current figure.
        fig = plt.figure(figsize=(15,8)) if self.show else plt.gcf()
        self.ax = fig.add_subplot(1, 1, 1)

        for recon, recon_name, linestyle, linewidth, alph in zip(recons, ["network reconstruction", "Bayesian estimation"], ['-', '--'], [5, 3], [.75, 1]):
            # Black copy of split 0 is the only labelled line: it provides the legend entry.
            self.ax.plot(recon[0, :end], color = 'k', linestyle = linestyle, linewidth = linewidth, alpha = .5, label = recon_name)
            for split_i in range(self.split_num):
                y = recon[split_i, :end]
                x = np.arange(len(y))
                col = self.split_cols[split_i, :]
                self.ax.plot(x, y, color = col, linestyle = linestyle, linewidth = linewidth, alpha = alph)
                self.ax.scatter(x, y, color = col, s = 250, alpha = alph)
        self.ax.set_xlabel("time", fontdict=self.font_dict)
        self.ax.legend(loc = 'upper right', fontsize = 15)
        self.ax.set_title(f"PLANT_ID = {plant_ID}")
        plt.ylim([0, 1])
        if self.show: plt.show()
    
    def plot_update_to_NOGO(self, mem = 25, model_free = False):
        """Scatter network PGO updates against Bayesian (or model-free) updates.

        Planted block types i = 0, 1, 2 are zipped with ``self.plant_type``. For each,
        the reconstruction is recomputed and the PGO change at diff index 1, 3 or 5
        is taken for every split from:
        - the network (``self.recons['TRUE PGO']``);
        - the reference estimate (``self.DV_mus['BAYES PGO']`` or, if ``model_free``,
          ``self.DV_mus['MODEL FREE']``, at memory ``mem``).
        Points of one type are joined by a black line and drawn against the identity
        line.
        """
        plt.subplots(1,1, figsize = (10, 10), tight_layout = True)
        net_diff, ref_diff = [np.zeros((self.PGO_N, 3)) for _ in range(2)]
        labels = ["1 preceeding GO", "3 preceeding GO", "5 preceeding GO"]
        markers = ['o', 's', 'v']
        steps = [1, 3, 5]
        # The reference and its axis label do not depend on the plant type. Binding
        # them before the loop also avoids a NameError when plant_type is empty.
        ref_key, xlab = ('MODEL FREE', 'model free delta') if model_free else ('BAYES PGO', 'bayesian delta')

        unity = np.linspace(-.3, .3, 20)
        plt.plot(unity, unity)

        for i, (_plant, update_step) in enumerate(zip(self.plant_type, steps)):
            self.get_indices(col = 'block', eps_init = self.held_out, planted = True, plant_ID = i, align_on = 'onset', flatten = True)
            self.run_trajectory(plot=False)
            self.get_reconstruction(mem)

            net_diff[:, i] = np.diff(self.recons['TRUE PGO'], axis = -1)[:, update_step]
            ref_diff[:, i] = np.diff(self.DV_mus[ref_key][mem], axis = -1)[:, update_step]

            plt.plot(ref_diff[:, i], net_diff[:, i], '-', c = 'k', alpha = .75, linewidth = 1, zorder = -1)
            # Small split-0 marker carries this plant type's legend label.
            plt.scatter(ref_diff[0, i], net_diff[0, i], color = self.split_cols[0], marker = markers[i], s = 200, label = labels[i])

            for split in range(self.split_num):
                plt.scatter(ref_diff[split, i], net_diff[split, i], color = self.split_cols[split], marker = markers[i], s = 500, edgecolors = 'k')

        plt.xlabel(xlab); plt.ylabel("Network delta"); plt.legend(loc=2, prop={'size': 20}); plt.xlim([-.1, .23]); plt.ylim([-.1, .23])
        if self.show: plt.show()
                
    """ Plot linear regression coefficient cosine similarity to PCs"""        
    # def plot_coef_PC_sim(self, mem = 20):
    #     regressor_N = len(self.regress_on)
    #     similarities = np.zeros((regressor_N, self.PC_dim))
    #     for r_i in range(regressor_N):
    #         for PC_i in range(self.PC_dim):
    #             first_neuron = int(r_i * self.hid_dim)
    #             final_neuron = int((r_i + 1) * self.hid_dim)
    #             reg = self.regressor['TRUE PGO'].coef_[mem, first_neuron : final_neuron]
    #             similarities[r_i, PC_i] = self.get_cosin_similarity(reg, self.basis[PC_i])
    #     plt.plot(similarities.T, 'o'); plt.legend(self.regress_on); plt.xlabel("PC dim")
        
    """ plot maximum (over mem durs) correlations between PCs and model variables """
    # def plot_PC_bayes_corrs(self):
    #     PC1_corr, PC2_corr, PC3_corr = [np.zeros((3, self.max_mem)) for _ in range(3)]
    #     stat_names = ["bayes PGO", "P safe", "mean relevance"]
    #     for m in range(self.max_mem):
    #         if m > 0:
    #             statistics = [self.bayes_PGO_flat[m], self.Psafe_flat[m], self.weight_flat[m, :m].mean(0) ]
    #             for stat_ind, (name, bayes) in enumerate(zip(stat_names, statistics)): 
    #                 PC1_corr[stat_ind, m] =  self.correlate(bayes, self.PC_flat[0,None]);
    #                 PC2_corr[stat_ind, m] =  self.correlate(bayes, self.PC_flat[1,None]);
    #                 PC3_corr[stat_ind, m] =  self.correlate(bayes, self.PC_flat[2,None]);

    #     fig, ax = plt.subplots(1,1, figsize = (20, 15), tight_layout = True);
    #     for PC_ind, (PC_name, PC_corr) in enumerate(zip(["PC1", "PC2", "PC3"],[PC1_corr, PC2_corr, PC3_corr])):
    #         for stat_ind, stat_name in enumerate(stat_names):
    #             abs_corr = np.abs(PC_corr[stat_ind])
    #             x = PC_ind + .1*stat_ind -.1
    #             plt.bar(x, np.max(abs_corr), width = .1, alpha = .5, color = ["C0", "C1", "C2"][stat_ind])
    #             self.addlabels(x, np.max(abs_corr), f"mem = {np.argmax(abs_corr)}")
    #         plt.xticks([0, 1, 2], ["PC1", "PC2", "PC3"]);  plt.ylabel("Abs correlation"); 
    #         plt.legend(stat_names); plt.title("Correlation (max across mem durations)\n between PCs and Model Variables")

    # def plot_PC_bayes_corrs_extended(self):
    #     self.get_indices(col = 'block', From = 0,  Til = 50, planted = True, align_on = 'action', flatten = True)
    #     self.run_trajectory(plot = False)
    #     self.flatten_bayes()
        
    #     stat_names = ["bayes PGO", "PGO update", "P safe", "mean relevance", "mean input", "last input"]
    #     F_corr, I_corr, O_corr, C_corr, STM_corr, LTM_corr, PC1_corr, PC2_corr, PC3_corr = [np.zeros((len(stat_names), self.max_mem)) for _ in range(9)]
    #     for m in range(self.max_mem):
    #         if m > 0:
    #             statistics = [self.bayes_PGO_flat[m],np.diff(self.bayes_PGO_flat[m], append = 0), self.Psafe_flat[m], self.weight_flat[m, :m].mean(0), self.mem_flat[:m, :].mean(0),  self.mem_flat[0, :]]
    #             for stat_ind, (name, bayes) in enumerate(zip(stat_names, statistics)): 
    #                 F_corr[stat_ind, m] = self.correlate(bayes, self.F_gate_flat.mean(0))
    #                 I_corr[stat_ind, m] = self.correlate(bayes, self.I_gate_flat.mean(0))
    #                 O_corr[stat_ind, m] = self.correlate(bayes, self.O_gate_flat.mean(0))
    #                 C_corr[stat_ind, m] = self.correlate(bayes, self.C_gate_flat.mean(0));
    #                 STM_corr[stat_ind, m] = self.correlate(bayes, self.output_flat.mean(0));
    #                 LTM_corr[stat_ind, m] = self.correlate(bayes, self.LTM_flat.mean(0));
    #                 PC1_corr[stat_ind, m] =  self.correlate(bayes, self.PC_flat[0,None]);
    #                 PC2_corr[stat_ind, m] =  self.correlate(bayes, self.PC_flat[1,None]);
    #                 PC3_corr[stat_ind, m] =  self.correlate(bayes, self.PC_flat[2,None]);

    #     fig, ax = plt.subplots(1,1, figsize = (20, 15), tight_layout = True);
    #     for PC_ind, (PC_name, PC_corr) in enumerate(zip(["PC1", "PC2", "PC3"],[PC1_corr, PC2_corr, PC3_corr])):
    #         for stat_ind, stat_name in enumerate(stat_names):
    #             abs_corr = np.abs(PC_corr[stat_ind])
    #             x = PC_ind + .1*stat_ind -.1
    #             plt.bar(x, np.max(abs_corr), width = .1, alpha = .5, color = ["C0", "C1", "C2", "C3", "C4", "C5"][stat_ind])
    #             self.addlabels(x, np.max(abs_corr), f"mem = {np.argmax(abs_corr)}")
    #         plt.xticks([0, 1, 2], ["PC1", "PC2", "PC3"]);  plt.ylabel("Abs correlation"); 
    #         plt.legend(stat_names); plt.title("Correlation (max across mem durations)\n between PCs and Model Variables")
    
    """ plot bayesian posterior distribution and compare to network regression to PGO """
    # def plot_bayes_through_steps(self, end = 5000):
    #     fig, ax = plt.subplots(1,1, figsize = (30, 10), tight_layout = True)
    #     pred_from = self.C_gate_flat # predict from cell only 
    #     net_pred_PGO = self.regressor['TRUE PGO'].predict(pred_from.T).T[-1]
    #     plt.plot(self.bayes_resolution - net_pred_PGO[:end]*self.bayes_resolution, alpha = .75)        
    #     plt.plot(self.bayes_resolution - self.bayes_PGO_flat[-1, :end]*self.bayes_resolution, alpha = .75)        
    #     plt.plot(self.bayes_resolution - self.PGO_flat[:end]*self.bayes_resolution, '--', linewidth = 3)
    #     sns.heatmap(np.flip(self.dist_flat[:,:end], 0), yticklabels = np.round(np.linspace(1, 0, self.bayes_resolution), 2))
    #     if self.show: plt.show()
        
    #     end = 5000
    #     #self.preprocess_data(testing)
    #     self.cog_map = True
    #     # self.run_PCA()
    #     #self.get_bayes()
    #     self.get_bayes_flow();
    #     self.get_indices(flatten = True)
    #     self.flatten_bayes()
    #     fig, ax = plt.subplots(1,1, figsize = (30, 10), tight_layout = True)
    #     # net_pred_PGO = self.PGO_reg.predict(self.pred_from.T).T
    #     # plt.plot(net_pred_PGO[:end], alpha = .75)     
    #     plt.plot(self.PC_flat[1,:end], '--', linewidth = 3)        
    #     plt.plot(self.PGO_flat[:end], '--', linewidth = 3)
    #     if self.show: plt.show(); 
        
    #     self.plot_PC_phase_space()
              
    #     fig, ax = plt.subplots(1,1, figsize = (30, 5), tight_layout = True)
    #     plt.plot(20*self.PGO_flat[:end], '--', linewidth = 3)
    #     bayes_dist =np.vstack(self.flatten_trajectory(self.raw_to_structured(self.joint_dist_log.sum(0), dim = 2)[self.trial_inds]))
    #     sns.heatmap(bayes_dist[:end].T)
    #     if self.show: plt.show();
    #     fig, ax = plt.subplots(1,1, figsize = (30, 5), tight_layout = True)
    #     plt.plot(20*self.PGO_flat[:end], '--', linewidth = 3)
    #     net_dist = self.dist_reg.predict_proba(self.pred_from.T)
    #     sns.heatmap(net_dist[:end, :].T)
    #     if self.show: plt.show()
    #     sns.heatmap(self.correlate(bayes_dist.T, net_dist.T));
        
###############################################################################################################################################

    def make_color_bar(self):
        """Add a colorbar for ``self.Cmap`` to the current figure and label it.

        The label is the fixed text 'P(GO|unsafe)' followed by a newline,
        rendered with ``self.font_dict``. ``self.Cmap`` is presumably a
        matplotlib mappable prepared elsewhere in the class (TODO confirm).
        """
        # The colorbar is created first and the label is applied to the returned
        # object, which is the same call order as building it in two steps.
        plt.colorbar(self.Cmap).set_label('P(GO|unsafe)\n', fontdict=self.font_dict)
        
    def save_SVG(self, name, directory="/home/johns/anaconda3/envs/PFC_env/PFC/Data"):
        """Save the current matplotlib figure as an SVG file.

        Args:
            name: File name without extension; ``.svg`` is appended. It may
                contain sub-folders relative to ``directory``.
            directory: Folder to write into. The default is the path that used
                to be hard-coded, so existing callers write to exactly the same
                file. Pass a different folder on machines where that
                user-specific path does not exist.
        """
        # Join with a plain '/' (not os.path.join) so that the resulting path is
        # byte-identical to the previous behaviour for every `name`, including
        # names that start with '/'.
        plt.savefig(f"{directory}/{name}.svg")
    
    def addlabels(self, x, y, s):
        """Write the text ``s`` at position (x, y) on the current matplotlib axes.

        Args:
            x: Horizontal position of the text.
            y: Vertical position of the text.
            s: Text to draw (callers in this class pass an f-string label).
        """
        plt.text(x=x, y=y, s=s)
     