import numpy as np; import torch; import scipy.linalg as lin; import pylab as plt; 
from numpy import dot; from numpy.linalg import norm;  import matplotlib.cm as cm; from scipy.stats import spearmanr as spr
class helper_functions():
    def __init__(self, **params):
        """Store every keyword argument as an instance attribute of the same name.

        No validation is performed; ``helper_functions(a=1).a == 1``.
        """
        vars(self).update(params)
       
    """ analytical analysis helpers"""
    def get_cosin_similarity(self, a, b):
        """Return the cosine similarity of ``a`` and ``b`` after mean-centering each.

        Centering first makes this equal to the Pearson correlation of the two
        vectors. Inputs may be NumPy arrays or any array-like (lists, tuples); they
        are converted to float arrays, so plain Python sequences work too.

        Args:
            a, b: 1-D array-likes of equal length.

        Returns:
            Scalar in [-1, 1]; ``nan`` (with a NumPy runtime warning) when either input
            is constant, because its centered norm is zero.
        """
        a = np.asarray(a, dtype=float)
        b = np.asarray(b, dtype=float)
        a = a - a.mean()
        b = b - b.mean()
        return dot(a, b) / (norm(a) * norm(b))


    """ splitting data """ 
    
    def get_indices(self, col = "block", From = 0, Til = 1e9, stim_above = None, stim_below = None, planted = None,
                    plant_PGO = None, plant_ID = None, prev_PGO = None, curr_PGO = None, PGO_in_list = (),
                    plant_PGO_in_list = (), rew = None, needs_acted = False, flatten = False, prep_mus = True,
                    eps_init = 0, eps_final = None, cross_validation = "testing", til_action = False,
                    from_action = False, align_on = "onset"):
        """Select the trials matching every given filter, then rebuild the derived indices.

        The selected trial positions are stored in ``self.trial_inds``. After that,
        ``postprocess_indices`` recomputes the per-step index arrays. Depending on
        ``flatten`` and ``prep_mus``, it also flattens the data and allocates
        per-condition buffers. A filter left at its default (``None``, empty or
        ``False``) is not applied.

        Args:
            col: variable the selected trials are split on (see ``get_split``);
                ``'all'`` gives a single condition.
            From, Til: inclusive bounds on ``self.trial`` (the trial number counted
                from the last context switch).
            stim_above, stim_below: strict lower/upper bounds on ``self.stim_durs``.
            planted: keep trials whose ``self.plant_inds`` equals ``int(planted)``.
            plant_PGO, plant_ID, prev_PGO, curr_PGO, rew: equality filters on
                ``self.plant_PGOs``, ``self.plant_IDs``, ``self.last_PGOs``,
                ``self.PGOs`` and ``self.rews`` respectively.
            PGO_in_list, plant_PGO_in_list: membership filters on ``self.PGOs`` /
                ``self.plant_PGOs``; empty (or ``None``) disables the filter.
            needs_acted: keep only trials with ``self.all_times != self.inaction``.
            flatten, prep_mus, til_action, from_action, align_on: forwarded to
                ``postprocess_vars``.
            eps_init, eps_final: strict bounds on ``self.eps``. A falsy
                ``eps_final`` (``None`` or 0) means ``self.eps[-1]``.
            cross_validation: ``None`` keeps every episode. ``"training"`` keeps
                ``eps < eps[-1] / 2``. Any other value keeps ``eps >= eps[-1] / 2``.
        """
        self.postprocess_vars(col, align_on, til_action, from_action, flatten, prep_mus)

        # Each condition is a boolean mask over trials, or the scalar 1 when the
        # filter is disabled (1 is neutral under the product below).
        cond_A = (self.trial >= From) * (self.trial <= Til)
        cond_B = self.last_PGOs == prev_PGO if prev_PGO is not None else 1
        cond_C = self.PGOs == curr_PGO if curr_PGO is not None else 1
        cond_D = 1 if planted is None else self.plant_inds == int(planted)
        cond_E = self.plant_PGOs == plant_PGO if plant_PGO is not None else 1
        cond_F = self.all_times != self.inaction if needs_acted else 1
        cond_G = self.stim_durs > stim_above if stim_above is not None else 1
        cond_H = self.stim_durs < stim_below if stim_below is not None else 1
        cond_I = self.rews == rew if rew is not None else 1

        cond_K = (self.eps > eps_init) * (self.eps < (eps_final or self.eps[-1]))
        cond_L = self.plant_IDs == plant_ID if plant_ID is not None else 1
        if cross_validation is None:
            cond_M = 1
        elif cross_validation == "training":
            cond_M = self.eps < self.eps[-1] / 2
        else:
            cond_M = self.eps >= self.eps[-1] / 2
        no_PGO_list = PGO_in_list is None or len(PGO_in_list) == 0
        no_plant_list = plant_PGO_in_list is None or len(plant_PGO_in_list) == 0
        cond_N = 1 if no_PGO_list else np.isin(self.PGOs, PGO_in_list)
        cond_O = 1 if no_plant_list else np.isin(self.plant_PGOs, plant_PGO_in_list)

        # Keep a trial only if every condition holds (each mask enters once).
        mask = (cond_A * cond_B * cond_C * cond_D * cond_E * cond_F * cond_G * cond_H
                * cond_I * cond_K * cond_L * cond_M * cond_N * cond_O)
        self.trial_inds = np.where(mask)[0]
        self.postprocess_indices()
        
    def postprocess_vars(self, col, align_on, til_action, from_action, flatten, prep_mus):
        """Record the split/alignment options used by the post-processing steps.

        Each argument is stored verbatim under the attribute of the same name. The
        stored values are later read by ``postprocess_indices``, ``get_split`` and
        the trajectory helpers.
        """
        self.col = col
        self.flatten = flatten
        self.prep_mus = prep_mus
        self.align_on = align_on
        self.til_action, self.from_action = til_action, from_action

    def postprocess_indices(self):
        """Derive per-trial and per-step index arrays for the trials in ``self.trial_inds``.

        The steps of the selected trials are laid out back to back on one flat axis.

        Per-trial attributes:
          * ``ends`` / ``cum_ends``: trial lengths (``trial_ends + 1``) and their
            running sums, i.e. each trial's exclusive end offset.
          * ``starts`` / ``cum_starts``: the lengths with a leading 0, and their
            running sums, i.e. each trial's start offset plus one trailing entry.
          * ``acts``, ``durs``, ``pots``, ``proj_trials``: values picked from
            ``act_times``, ``stim_durs``, ``last_pot`` and ``trial``.
          * ``cum_acts``: flat offset of each trial's action step.

        Per-step attributes:
          * ``full_inds``, ``full_til_act_inds``, ``full_from_act_inds``: flat step
            indices of whole trials, of steps before the action, and of steps from
            two after the action to the trial end.
          * ``pre_act`` / ``post_act``: 0/1 masks over ``full_inds`` built from the
            last two index arrays.

        Afterwards it flattens the data when ``self.flatten`` is True and recomputes
        the split. An empty selection yields empty index arrays instead of raising.
        """
        def concat_ranges(range_starts, range_stops, shift=0):
            # np.hstack([]) raises on an empty list, so an empty selection falls back
            # to an empty integer array.
            ranges = [np.arange(s + shift, e) for s, e in zip(range_starts, range_stops)]
            if not ranges:
                return np.array([], dtype=int)
            return np.hstack(ranges).astype(int)

        self.ends = self.trial_ends[self.trial_inds] + 1
        self.cum_ends = np.cumsum(self.ends)
        self.starts = np.insert(self.ends, 0, 0)
        self.cum_starts = np.cumsum(self.starts)
        self.acts = self.act_times[self.trial_inds]
        self.durs = self.stim_durs[self.trial_inds]
        self.pots = self.last_pot[self.trial_inds]
        self.cum_acts = (self.acts + self.cum_starts[:-1]).astype(int)
        self.proj_trials = self.trial[self.trial_inds]
        # Guard: max() of an empty selection would raise.
        self.trial_percent = 1 if len(self.trial_inds) == 0 else self.proj_trials/(self.proj_trials.max()+1)

        self.full_inds = concat_ranges(self.cum_starts, self.cum_ends)
        self.full_til_act_inds = concat_ranges(self.cum_starts, self.cum_acts)
        # Post-action steps start two steps after the action.
        self.full_from_act_inds = concat_ranges(self.cum_acts, self.cum_ends, shift=2)
        self.pre_act, self.post_act = [np.zeros(len(self.full_inds)) for _ in range(2)]
        self.pre_act[self.full_til_act_inds] = 1
        self.post_act[self.full_from_act_inds] = 1

        if self.flatten == True:
            self.flattening()
        self.get_split()
    
    def get_split(self):
        """Label each selected trial with its condition and summarise the conditions.

        ``self.col`` names the variable to split on: 'block', 'plant_PGO',
        'plant_ID', 'last', 'diff' or 'rew'. The value 'all' puts every trial in one
        condition labelled 0.

        Sets ``split_cond`` (one label per selected trial), ``split_trial_cols``
        (a colour per trial from ``self.Cmap``), ``split`` (the unique labels) and
        ``split_num``. When ``self.prep_mus`` is set, it also allocates the
        per-condition buffers.
        """
        if self.col != 'all':
            # 'diff' is planted PGO minus block PGO, rounded to 2 decimals; it is 0
            # when plant_PGOs contains None entries.
            if None in self.plant_PGOs:
                diff = 0
            else:
                diff = np.around((self.plant_PGOs - self.PGOs).astype(float), 2)
            candidates = [self.PGOs, self.plant_PGOs, self.plant_IDs, self.last_PGOs, diff, self.rews]
            labels = ['block', 'plant_PGO', 'plant_ID', 'last', 'diff', 'rew']
            chosen = self.get_data_of(candidates, labels, select = self.col)
            self.split_cond = chosen[self.trial_inds]
        else:
            self.split_cond = np.zeros(len(self.trial_inds))

        self.split_trial_cols = self.Cmap.to_rgba(self.split_cond.astype(float))
        self.split = np.unique(self.split_cond)
        self.split_num = len(self.split)
        if self.prep_mus:
            self.preprocess_mus()
        
                                                   
    def get_split_inds(self, split_after = False):
        """Collect the flat step indices of the trials in the current condition.

        With ``split_after`` True every selected trial is used (no splitting).
        Otherwise only trials whose ``split_cond`` equals ``self.split_curr`` are
        used. When ``self.til_action`` is set, each trial is cut at its action step
        (``cum_acts``), plus two steps when aligning on the action.

        Sets ``half``, ``split_inds``, ``split_starts``, ``split_ends``,
        ``full_split_ind_pairs`` and ``full_split_inds_flat``, then calls
        ``get_split_alignment``.
        """
        self.half = int(self.max_traj_steps/2)
        if not split_after:
            # only the trials of the current condition
            self.split_inds = np.where(self.split_cond == self.split_curr)[0]
        else:
            # every trial, i.e. the trajectory without splitting
            self.split_inds = np.arange(len(self.split_cond))

        self.split_starts = self.cum_starts[self.split_inds]
        self.split_ends = self.cum_ends[self.split_inds]
        if self.til_action:
            extra_steps = 2*(self.align_on == 'action')
            self.split_ends = self.cum_acts[self.split_inds] + extra_steps

        step_ranges = [np.arange(first, stop) for first, stop in zip(self.split_starts, self.split_ends)]
        self.full_split_ind_pairs = np.array(step_ranges, dtype = object)
        self.full_split_inds_flat = np.hstack(self.full_split_ind_pairs)
        self.get_split_alignment()
        
    def get_split_alignment(self):
        """Compute the per-trial offset that aligns the current condition's trajectories.

        ``self.align_on`` selects the event to align on:
          * 'onset': offset 0.
          * 'action', 'W4L', 'last_pot': the action step, stimulus duration or
            ``last_pot`` value, each minus ``self.half``.

        Sets ``split_acts``, ``aligned_acts``, ``aligned_durs``, ``aligned_pot``,
        ``offset``, ``label_offset``, ``traj_xlabel`` and ``aligned_starts``.
        """
        inds = self.split_inds
        shift = self.half
        self.split_acts = self.acts[inds]
        self.aligned_acts = (self.acts[inds] - shift).astype(int)
        self.aligned_durs = (self.durs[inds] - shift).astype(int)
        self.aligned_pot = (self.pots[inds] - shift).astype(int)
        self.offset = self.get_data_of(
            [0, self.aligned_acts, self.aligned_durs, self.aligned_pot],
            ['onset', 'action', 'W4L', 'last_pot'],
            select = self.align_on,
        )
        self.label_offset = 0 if self.align_on == 'onset' else int(self.max_traj_steps/2)
        self.traj_xlabel = f"Time(s) from {self.align_on}"
        self.aligned_starts = self.split_starts + self.offset

    """ trajectory processing """

    def get_trajectory_step_inds(self):
        """Pick the flat step index of each trial in the current condition at ``self.step``.

        ``running_inds`` is each trial's aligned start plus ``self.step``. A trial
        contributes only while that index is:
          * at or after its start,
          * at or after its action step when ``self.from_action`` is set, and
          * before ``split_ends - 1``.

        Sets ``step_inds`` (flat indices of the contributing steps) and
        ``step_split_cond`` (the condition labels of the contributing trials). It
        also stores in ``emp_act_prob[split_i, step]`` how many trials of the
        condition act exactly at this aligned step.
        """
        running_inds = self.aligned_starts + self.step
        post_start = running_inds >= self.split_starts
        if self.from_action:
            post_start = post_start * (running_inds >= self.cum_acts[self.split_inds])
        pre_end = running_inds < self.split_ends - 1
        conds = np.where(post_start * pre_end)[0]          # positions within split_inds
        self.step_inds = running_inds[conds]
        # conds indexes the current condition's trials, while split_cond is indexed
        # over all selected trials, so map the positions back through split_inds.
        self.step_split_cond = self.split_cond[self.split_inds[conds]]
        self.emp_act_prob[self.split_i, self.step] = np.sum(self.split_acts == (self.step + self.offset))
       
    def get_trajectory_diffs(self):
        """Average PC-space trajectories per condition and aligned time step.

        For each condition in ``self.split`` and each of ``self.max_traj_steps``
        aligned steps, fill:
          * ``PC_mus`` with the mean of ``PC_flat``,
          * ``PC_prev_mus`` with the mean of ``PC_prev``,
          * ``PC_traj_diffs`` with the mean of ``PC_diff``,
        each taken over the steps selected by ``get_trajectory_step_inds``.
        ``GO_traj_diffs`` / ``NOGO_traj_diffs`` average ``PC_diff`` over only those
        steps whose ``input_flat[1]`` / ``input_flat[0]`` is nonzero.

        ``PC_mus``, ``PC_prev_mus`` and ``PC_traj_diffs`` must already be allocated
        (``preprocess_mus`` does this). A mean over an empty selection yields NaN.
        """
        self.GO_traj_diffs = np.zeros((self.PC_dim, self.split_num, self.max_traj_steps))
        self.NOGO_traj_diffs = np.zeros((self.PC_dim, self.split_num, self.max_traj_steps))
        # Loop variables are stored on self on purpose: get_split_inds and
        # get_trajectory_step_inds read self.split_curr, self.split_i and self.step.
        for self.split_i, self.split_curr in enumerate(self.split):
            self.get_split_inds()
            for self.step in range(self.max_traj_steps):
                self.get_trajectory_step_inds()
                # positions (within step_inds) of steps with a GO / NOGO input
                GO_step_inds = np.where(self.input_flat[1, self.step_inds])[0]
                NOGO_step_inds = np.where(self.input_flat[0, self.step_inds])[0]
                self.PC_mus[:, self.split_i, self.step] = self.PC_flat[:, self.step_inds].mean(-1)
                self.PC_prev_mus[:, self.split_i, self.step] = self.PC_prev[:, self.step_inds].mean(-1)
                self.PC_traj_diffs[:, self.split_i, self.step] = self.PC_diff[:, self.step_inds].mean(-1)
                self.GO_traj_diffs[:, self.split_i, self.step] = self.PC_diff[:, self.step_inds][:, GO_step_inds].mean(-1)
                self.NOGO_traj_diffs[:, self.split_i, self.step] = self.PC_diff[:, self.step_inds][:, NOGO_step_inds].mean(-1)
        
    def get_GO_NOGO_inds(self):
        """Find flat steps with a GO / NOGO input, optionally restricted by action timing.

        The restriction mask ``act_cond`` is:
          * ``self.pre_act`` when ``til_action`` is set,
          * else ``self.post_act`` when ``from_action`` is set,
          * else 1 (no restriction).

        GO steps have a nonzero ``input_flat[1]`` and NOGO steps a nonzero
        ``input_flat[0]``. ``ACTION_inds`` is ``self.all_inds`` when unrestricted,
        otherwise the nonzero positions of ``act_cond``.
        """
        if self.til_action:
            act_cond = self.pre_act
        elif self.from_action:
            act_cond = self.post_act
        else:
            act_cond = 1
        self.GO_inds = np.where(self.input_flat[1] * act_cond)[0]
        self.NOGO_inds = np.where(self.input_flat[0] * act_cond)[0]
        restricted = self.til_action or self.from_action
        self.ACTION_inds = np.where(act_cond)[0] if restricted else self.all_inds

    def flatten_bayes(self):
        """Flatten the Bayesian / model-free observer logs of the selected trials.

        Each per-trial log indexed by ``self.trial_inds`` is concatenated along time.
        The ground-truth PGO and safe-state traces are repeated to ``self.max_mem``
        rows (presumably to match the row count of the model outputs; confirm).
        Fills ``self.DVs`` and the matching ``self.DV_names``.
        """
        sel = self.trial_inds
        flat = self.flatten_trajectory
        self.dist_flat = np.hstack(self.dist_log[sel])
        self.weight_flat = np.hstack(self.weight_log[sel])
        self.mem_flat = flat(self.mem_log[sel])
        self.bayes_Psafe_flat = flat(self.Psafe_log[sel])
        self.bayes_PGO_flat = flat(self.bayes_PGO_log[sel])
        self.model_free_flat = flat(self.model_free_log[sel])

        def tiled(structured):
            # flatten, then stack max_mem copies as rows
            return flat(structured[sel])[None,:].repeat(self.max_mem, axis = 0)

        self.truth_PGO_flat = tiled(self.step_PGO)
        self.truth_SAFE_flat = tiled(self.safe_backbone)
        self.DV_names = ["TRUE PGO", "TRUE SAFE", "BAYES PGO", "BAYES P SAFE", "MODEL FREE"]
        self.DVs = [self.truth_PGO_flat, self.truth_SAFE_flat, self.bayes_PGO_flat, self.bayes_Psafe_flat, self.model_free_flat]

    def flattening(self):
        """Flatten the per-trial logs of the selected trials and build the PC view.

        Every structured / per-trial log indexed by ``self.trial_inds`` is
        concatenated along time into a ``*_flat`` attribute. The network output is
        projected with ``self.proj.transform`` into ``self.PC_flat``. The decoders
        ``self.Q_reg`` and ``self.PGO_reg`` are then applied to ``self.pred_from``
        (built by ``get_reg_vec``):
          * ``net_belief_flat`` is ``-log(1 - P(class 1))`` from ``Q_reg``;
          * ``net_theta_flat`` is the ``PGO_reg`` prediction.
        The Spearman correlation between the decoded and the flow theta is printed.

        When ``self.cog_map`` is 'regression' or 'theory', the first two rows of
        ``PC_flat`` are replaced by the decoded or theoretical belief/theta, and the
        axis labels are updated. Ends with ``make_prev_diff``.
        """
        sel = self.trial_inds
        flat = self.flatten_trajectory
        self.MF_theta_flat = np.vstack(flat(self.MF_theta_structured[sel])).T
        self.theta_dist_flat = np.vstack(np.hstack(self.theta_dist_structured[sel])).T
        self.flow_theta_flat = flat(self.flow_theta_structured[sel])
        self.flow_belief_flat = flat(self.flow_belief_structured[sel])
        self.factorized_theta_flat = flat(self.factorized_theta_structured[sel])
        self.factorized_belief_flat = flat(self.factorized_belief_structured[sel])
        self.lick_prob_flat = flat(self.lick_prob[sel])
        self.value_flat = flat(self.value[sel])
        self.Q_flat = flat(self.Q_values[sel])
        self.output_flat = flat(self.net_output[sel])
        self.input_flat = flat(self.net_input[sel])
        self.consec_flat = flat(self.consec_stim[sel])
        self.PGO_flat = flat(self.step_PGO[sel])
        self.PSAFE_flat = flat(self.safe_backbone[sel])
        self.plant_PGO_flat = flat(self.step_plant_PGO[sel])
        self.LTM_flat = flat(self.LTM[sel])
        self.I_gate_flat = flat(self.i_gate[sel])
        self.F_gate_flat = flat(self.f_gate[sel])
        self.C_gate_flat = flat(self.c_gate[sel])
        self.O_gate_flat = flat(self.o_gate[sel])
        self.PC_flat = self.proj.transform(self.output_flat.T).T
        self.get_reg_vec()

        self.PC_Xlab = "PC1 projection"
        self.PC_Ylab = "PC2 projection"
        self.PC_Zlab = "PC3 projection"
        self.net_belief_flat = -np.log(1-self.Q_reg.predict_proba(self.pred_from.T)[:,1])
        self.net_theta_flat = self.PGO_reg.predict(self.pred_from.T)
        # Diagnostic: rank correlation between decoded and flow theta (computed once).
        print(spr(self.net_theta_flat, b = self.flow_theta_flat))
        if self.cog_map == 'regression':
            self.PC_flat[0, :] = self.net_belief_flat
            self.PC_Xlab = r"$-\log(1 - \hat s^{net})$"
            self.PC_flat[1, :] = self.net_theta_flat
            self.PC_Ylab = r"$\hat\theta^{net}$"
        if self.cog_map == 'theory':
            self.PC_flat[0, :] = self.flow_belief_flat.astype('float32')
            self.PC_flat[1, :] = self.flow_theta_flat
            # NOTE(review): this row holds the theoretical (flow) belief, yet the label
            # carries the "net" superscript; confirm the intended label.
            self.PC_Xlab = r"$-\log(1 - \hat s^{net})$"
            self.PC_Ylab = r"$\hat\theta$"

        self.PC_labs = [self.PC_Xlab, self.PC_Ylab, self.PC_Zlab]
        self.make_prev_diff()
        
    def make_prev_diff(self):
        """Build one-step differences and one-step-lagged copies of the flat traces.

        This covers ``PC_flat`` and the net / flow / factorized belief and theta
        traces. For each of them:
          * ``*_diff`` is ``np.diff`` along the last axis with a prepended 0, so its
            first entry is the trace's first value;
          * ``*_prev`` is the trace shifted right by one step with a leading 0.
        """
        def diff_and_prev(trace):
            return np.diff(trace, prepend=0), np.concatenate((np.zeros(1), trace[:-1]), -1)

        self.PC_diff = np.diff(self.PC_flat, prepend=0)
        self.PC_prev = np.concatenate((np.zeros((self.PC_dim, 1)), self.PC_flat[:, :-1]), -1)
        self.net_belief_diff, self.net_belief_prev = diff_and_prev(self.net_belief_flat)
        self.net_theta_diff, self.net_theta_prev = diff_and_prev(self.net_theta_flat)
        self.flow_belief_diff, self.flow_belief_prev = diff_and_prev(self.flow_belief_flat)
        self.flow_theta_diff, self.flow_theta_prev = diff_and_prev(self.flow_theta_flat)
        self.factorized_belief_diff, self.factorized_belief_prev = diff_and_prev(self.factorized_belief_flat)
        self.factorized_theta_diff, self.factorized_theta_prev = diff_and_prev(self.factorized_theta_flat)
        
    def get_reg_vec(self):
        """Assemble the decoder input ``self.pred_from`` from the flattened network state.

        Rows are stacked in this order: network output, C gate, LTM, O gate, I gate,
        F gate (the corresponding ``*_flat`` attributes). ``self.pred_from_N`` is the
        resulting number of rows (features).
        """
        # PC singles and pairwise PC products used to be built here as alternative
        # features; they were never used, so they are no longer computed.
        features = (self.output_flat, self.C_gate_flat, self.LTM_flat, self.O_gate_flat, self.I_gate_flat, self.F_gate_flat)
        self.pred_from = np.concatenate(features, axis = 0)
        self.pred_from_N = self.pred_from.shape[0]

    def raw_to_structured(self, X, time_dim = 1, running_ind = 0):
        """Split a time-concatenated array into a per-trial, per-step object array.

        ``X`` holds the steps of all ``self.data_len`` trials back to back along its
        time axis, starting at position ``running_ind``. Trial ``t`` spans
        ``self.trial_ends[t] + 1`` steps. The time axis is axis 0 when ``time_dim``
        is 1, axis 1 when it is 2, axis 2 when it is 3, and axis 3 for any other
        value.

        Returns:
            Object array of length ``self.data_len``. Element ``t`` is an object
            array holding the slice of ``X`` for each step of trial ``t``.
        """
        leading_axes = {1: 0, 2: 1, 3: 2}.get(time_dim, 3)

        def take(position):
            if leading_axes == 0:
                return X[position]
            return X[(slice(None),) * leading_axes + (position,)]

        trials = np.empty(self.data_len, dtype = object)
        cursor = running_ind
        for trial_idx in range(self.data_len):
            n_steps = self.trial_ends[trial_idx] + 1
            steps = np.empty(n_steps, dtype = object)
            for step_idx in range(n_steps):
                steps[step_idx] = take(cursor)
                cursor += 1
            trials[trial_idx] = steps
        return trials
                
    def raw_to_function(self, func):
        """Call ``func()`` once per step of every trial, exposing the step's state on ``self``.

        Before each call the following attributes are set:
          * ``curr_trial``, ``curr_step``;
          * ``i``: running flat step index;
          * ``PGO``, ``act``, ``end``: the trial's values;
          * ``stim``: ``int(backbone_flat[i])``;
          * ``acting``: 1 exactly at the action step;
          * ``post_act``: 1 after it.

        ``self.post_act`` is overwritten with an int here, while
        ``postprocess_indices`` stores an array under the same name.
        """
        self.i = 0
        for self.curr_trial in range(self.data_len):
            trial = self.curr_trial
            self.PGO, self.act = self.PGOs[trial], self.act_times[trial]
            self.end = self.trial_ends[trial] + 1
            for self.curr_step in range(self.end):
                step = self.curr_step
                self.stim = int(self.backbone_flat[self.i])
                self.acting = int(step == self.act)
                self.post_act = int(step > self.act)
                func()
                self.i += 1
            
    
    """ preprocessing and postprocessing """     

    def preprocess_mus(self):
        """Allocate zeroed per-condition, per-time-step buffers plus legend and colour info.

        With ``grid = (split_num, max_traj_steps)`` the buffer shapes are:
          * hidden-unit means: ``(hid_dim,) + grid``;
          * PC buffers: ``(PC_dim,) + grid``;
          * decoder-input means: ``(pred_from_N,) + grid``;
          * scalar traces: ``grid``;
          * ``act_col``: ``grid + (4,)``.

        ``split_leg`` holds one legend label per condition and ``split_cols`` one
        colour per condition from ``self.Cmap``.
        """
        grid = (self.split_num, self.max_traj_steps)

        def zeros(*lead):
            return np.zeros(lead + grid)

        (self.output_mus, self.LTM_mus, self.I_gate_mus, self.F_gate_mus,
         self.C_gate_mus, self.O_gate_mus) = [zeros(self.hid_dim) for _ in range(6)]
        self.PC_mus, self.PC_traj_diffs, self.PC_prev_mus = [zeros(self.PC_dim) for _ in range(3)]
        (self.emp_act_prob, self.GO_mus, self.PGO_mus, self.lick_prob_mus, self.value_mus, self.QDIFF_mus,
         self.flow_theta_mus, self.flow_belief_mus, self.factorized_theta_mus, self.factorized_belief_mus,
         self.net_theta_mus, self.net_belief_mus,
         self.flow_theta_vars, self.flow_belief_vars, self.factorized_theta_vars, self.factorized_belief_vars,
         self.net_theta_vars, self.net_belief_vars) = [zeros() for _ in range(18)]
        self.pred_from_mus = zeros(self.pred_from_N)
        self.act_col = np.zeros(grid + (4,))
        self.split_leg = [f"{self.col} {i}" for i in self.split]
        self.split_cols = self.Cmap.to_rgba(np.linspace(0, 1, self.split_num))
            
        
    def norm_W(self, W):
        """Scale each row of ``W`` to unit Euclidean norm (each row divided by its norm)."""
        row_norms = lin.norm(W, axis = 1, keepdims = True)
        return W / row_norms

    def flatten_trajectory(self, trajectory):
        """Join per-trial pieces into one NumPy array with ``np.hstack``.

        Each piece is first passed through ``self.to(..., 'np')``, so torch tensors
        are accepted. 1-D pieces are joined end to end; higher-rank pieces are joined
        along axis 1.
        """
        as_numpy = [self.to(piece, 'np') for piece in trajectory]
        return np.hstack(as_numpy)
    
    def trial_var_to_step_var(self, var):
        """Broadcast ``var`` against ``self.data.backbone``.

        ``var`` is multiplied by ``1 + 0 * backbone``, an all-ones array shaped like
        the backbone. Non-finite backbone entries therefore propagate as NaN.
        """
        ones_like_backbone = 1 + 0*(self.data.backbone)
        return var*ones_like_backbone

    def to(self, X, to):
        """Convert ``X`` between NumPy and torch.

        With ``to='np'``, torch tensors are detached, moved to the CPU and converted
        to NumPy. With ``to='tensor'``, non-tensors go through ``torch.from_numpy``
        (so they must be NumPy arrays), then move to ``self.device`` as float32.
        Plain lists are converted element by element, recursively. Any other value,
        or a value already of the requested kind, is returned unchanged.
        """
        if type(X) is list:
            return [self.to(item, to) for item in X]
        is_tensor = type(X) is torch.Tensor
        if to == 'np':
            return X.detach().cpu().numpy() if is_tensor else X
        if to == 'tensor' and not is_tensor:
            return torch.from_numpy(X).to(self.device).float()
        return X
    
    def get_data_of(self, data, names, select, tensor = False):
        """Return the entry of ``data`` whose label in ``names`` equals ``select``.

        Args:
            data: sequence of candidates, aligned with ``names``.
            names: labels for the entries of ``data``; must contain ``select`` exactly once.
            select: label to look up.
            tensor: if True, ``data`` is converted to NumPy first and the result is
                returned as a float tensor on ``self.device``.

        Returns:
            ``np.array`` of the selected entry, or a torch tensor when ``tensor`` is set.

        Raises:
            ValueError: if ``select`` matches no label, or more than one label.
        """
        if tensor == True:
            return self.to(self.get_data_of(self.to(data, 'np'), names, select), 'tensor')
        names = np.array(names, dtype = str)
        matches = np.flatnonzero(names == select)
        # int() of a 1-element array is deprecated since NumPy 1.25, so require exactly
        # one match and index with a plain Python int.
        if matches.size != 1:
            raise ValueError(f"expected exactly one match for {select!r} in {list(names)}, found {matches.size}")
        return np.array(data[int(matches[0])])

    def correlate(self, X, Y):
        """Pearson correlation between every row of ``X`` and every row of ``Y``.

        Rows are variables and the last axis holds samples. For 2-D inputs of shapes
        ``(n, T)`` and ``(m, T)`` the result has shape ``(n, m)``.
        """
        Xc = X - X.mean(-1, keepdims = True)
        Yc = Y - Y.mean(-1, keepdims = True)
        ss_x = (Xc**2).sum(-1, keepdims = True)
        ss_y = (Yc**2).sum(-1, keepdims = True)
        denom_sq = ss_x @ ss_y.T
        if type(denom_sq) != float:
            denom_sq = denom_sq.astype(float)
        covariance = Xc @ Yc.T
        return covariance / np.sqrt(denom_sq)
        
    def norm_by_sum(self, A, dim):
        """Divide ``A`` by its sum along ``dim``, so every slice along ``dim`` sums to 1."""
        totals = A.sum(dim, keepdims = True)
        return A / totals
    
    def where(self, X, Y = None):
        """Return axis-0 indices where ``X`` is nonzero, or where ``X == Y`` when ``Y`` is given."""
        mask = X if Y is None else X == Y
        return np.where(mask)[0]
    
    def percent_complete(self, i, total):
        """Print progress as a whole-number percentage roughly every 10% of ``total``.

        Prints ``"<pct>%"`` whenever ``i`` is a multiple of ``int(total / 10)``. That
        step is clamped to at least 1, so totals below 10 no longer divide by zero.
        Nothing is printed when ``total`` is not positive.
        """
        if total <= 0:
            return
        every = max(int(total/10), 1)
        if i % every == 0:
            print(str(int(100 * i/total)) + "%")
            
    """ LSTM Dynamics helpers"""
    def get_null_norms(self, W, null):
        """Return, as a NumPy array, the column-wise 2-norms of the torch product ``W @ null``."""
        projected = W @ null
        return self.to(projected.norm(dim=0), 'np')

    def sort_dynamic(self, W):
        """Rearrange ``W`` with ``np.take_along_axis``.

        Axis 0 is reindexed by ``self.sort_inds_0`` first, then axis 1 by
        ``self.sort_inds_1``. Torch inputs are converted to NumPy first.
        """
        W_np = self.to(W, 'np')
        rows_done = np.take_along_axis(W_np, self.sort_inds_0, 0)
        return np.take_along_axis(rows_done, self.sort_inds_1, 1)
   
    def get_split_from_mu(self, mus):
        """For each 3-D array in ``mus``, take condition ``self.split_i`` (axis 1).

        Each slice is returned transposed, as a float tensor on ``self.device``.
        """
        selected = []
        for mu in mus:
            condition_slice = mu[:, self.split_i, :].T
            selected.append(self.to(condition_slice, 'tensor'))
        return selected

    def get_mean_across_dim(self, X, dim):
        """Average each array in ``X`` over axis ``dim``.

        Each result gets a leading singleton axis and is returned as a float tensor
        on ``self.device``.
        """
        def mean_as_row(x):
            return self.to(x.mean(dim)[None, :], 'tensor')

        return list(map(mean_as_row, X))

    def get_layer(self, x, b, W, act):
        """Return ``act(clip(x + b) @ W.T)``, where ``clip`` bounds values to [-1, 1]."""
        bounded = self.clip(x + b)
        return act(bounded @ W.T)
 
    def clip(self, tensor):
        """Clamp every element of the torch ``tensor`` to the range [-1, 1]."""
        return torch.clamp(tensor, min=-1, max=1)

    """ mouse plotting helpers"""
    
    def smooth(self, X, bin_size, axis = 0, rescale = False):
        """Box-car smooth ``X`` along ``axis`` with a window of ``bin_size`` samples.

        Uses ``np.convolve(..., mode='full')``, so along ``axis`` the result is
        ``bin_size - 1`` samples longer than the input. With ``rescale``, the result
        is min-max normalised and then mapped back onto the input's range
        ``[X.min(), X.max()]``.
        """
        window = np.ones(bin_size) / bin_size
        out = np.apply_along_axis(lambda row: np.convolve(row, window, mode='full'), axis=axis, arr= X)
        if not rescale:
            return out
        out -= out.min()
        out *= 1/out.max()
        return out*(X-X.min()).max() + X.min()
    
    def get_survival(self, x):
        """Return ``x[..., t] * prod_{s < t} (1 - x[..., s])`` along the last axis.

        If ``x`` holds per-step hazards, this is the probability that the first event
        happens at step ``t``. Works for inputs of any rank, including 1-D. The older
        ``x[:, 1:]`` / ``x[:, :, 1:]`` slicing rejected 1-D input and sliced the wrong
        axis for 4-D and higher.
        """
        padded = np.insert(x, 0, np.zeros(1), axis = -1)   # leading 0 on the time axis
        current = padded[..., 1:]                          # equals x
        previous = padded[..., :-1]                        # x shifted right by one step
        return current * np.cumprod(1 - previous, axis = -1)

    def get_hazard(self, data):
        """Convert per-step event probabilities into discrete hazard rates.

        For a row ``p``, the hazard at step ``t`` is ``p[t] / (1 - sum_{s<t} p[s])``,
        i.e. ``p / (1 - cumsum(p) + p)``. The denominator is clipped below at 1e-5.

        Args:
            data: 2-D array with one row of probabilities per series.

        Returns:
            A ``(1, T)`` array when ``data`` has a single row; otherwise a list with
            one 1-D hazard array per row.
        """
        if data.shape[0] == 1:
            # data[0] is the whole row here, so this is 1 - cumsum(p) + p.
            return data/(np.clip(1-data.cumsum()+data[0], a_min = 1e-5, a_max = None))
        # Add the row elementwise (not its first element p[0]) so multi-row results
        # match the single-row branch above.
        return [p/(np.clip(1-p.cumsum()+p, a_min = 1e-5, a_max = None)) for p in data]
        
    def get_action_color(self):
        """Store a colour code for every trial's outcome in ``self.act_cols``.

        The colour is 'g' if the trial was rewarded (``rews > 0``); otherwise 'r' if
        its time exceeds ``self.inaction``; otherwise 'orange'.
        """
        def colour(reward, time):
            if reward > 0:
                return 'g'
            if time > self.inaction:
                return 'r'
            return 'orange'

        self.act_cols = np.empty(len(self.all_times), dtype = object)
        for idx, time in enumerate(self.all_times):
            self.act_cols[idx] = colour(self.rews[idx], time)
            
    def get_PGO_standards(self):
        """Build default legend labels, alphas, colours and plot kinds for the PGO levels.

        One entry per level in ``self.PGO_range``; ``self.PGO_N`` is the count.
        """
        n_levels = self.PGO_N
        self.PGO_LEG = [f"PGO = {p}" for p in self.PGO_range]
        self.PGO_ALPHA = n_levels * [.5]
        self.PGO_COLOR = self.Cmap.to_rgba(np.linspace(0, 1, n_levels))
        self.PGO_PLOT = n_levels * ["line"]

    def make_standard_plot(self, ys ,xlab, plots, alphs, cols, xax = None, ylab = None, xticks = None, title = None, xlim = None, ylim = None, hlines = None, err = None, leg = None, traj = False, save_SVG = False, fig_x = 10, fig_y = 5, show=True):
        """Draw one series per entry of ``ys`` on a shared figure.

        Args:
            ys: sequence of y-value arrays (for 'bar', one height per entry).
            xlab, ylab: axis labels (``ylab`` is optional).
            plots: per-series kind: 'line', 'scatter' (line with markers) or 'bar'.
            alphs: per-series alpha.
            cols: per-series colour; ``None`` falls back to the default cycle ``C{i}``.
            xax: optional per-series x values (default ``arange(len(y))``).
            xticks, title, xlim, ylim: optional axis decorations.
            hlines: optional per-series y value, drawn as a dash-dot horizontal line.
            err: optional per-series error band, drawn with ``fill_between``.
            leg: 'heat' adds a colour bar via ``make_color_bar``; other values are ignored.
            traj: during figure setup, call ``init_traj_plot`` instead of creating a figure.
            save_SVG: save ``<save_path><title>.svg`` and ``.png``; this also happens
                when ``self.save_plot`` is set.
            fig_x, fig_y: figure size when not sizing for the paper.
            show: call ``show()`` at the end.

        Figure setup runs only when ``self.show`` is truthy (a missing attribute
        counts as True). Otherwise drawing goes onto the current pylab figure.
        """
        if not hasattr(self, 'save_plot'):
            self.save_plot = False
            self.font_dict = {}
        font_dict = getattr(self, 'font_dict', {})

        # Draw through a separate local name: assigning to `plt` inside this method
        # would make it local for the whole function, and the branches that skip the
        # setup below would then fail with UnboundLocalError.
        pl = plt
        if getattr(self, 'show', True):
            if traj:
                self.init_traj_plot()
            else:
                import enzyme
                from importlib import reload
                # Best effort: reload the project module to pick up edits to its
                # plotting defaults; on failure keep the already-imported version.
                try:
                    reload(enzyme)
                except Exception:
                    pass
                TEXTWIDTH, init_mpl = enzyme.TEXTWIDTH, enzyme.init_mpl
                pl = init_mpl(usetex=False)
                if self.save_plot:
                    pl.figure(figsize=(TEXTWIDTH, TEXTWIDTH / 3))
                else:
                    pl.figure(figsize=(fig_x, fig_y))

        for i, (y, p, c, a) in enumerate(zip(ys, plots, cols, alphs)):
            c = c if c is not None else f"C{i}"
            x = np.arange(len(y)) if xax is None else xax[i]
            if p == 'line':
                pl.plot(x, y, color = c, alpha = a)
            if p == 'scatter':
                pl.plot(x, y, '-o', color = c, alpha = a)
            if p == 'bar':
                pl.bar(i, y, color = c, alpha = a)
            if hlines is not None:
                pl.axhline(hlines[i], color = c, alpha = a, linestyle = '-.')
            if err is not None:
                pl.fill_between(x, y-err[i], y+err[i], color = c, alpha = a/2)

        if xticks is not None:
            pl.xticks(np.arange(len(xticks)), xticks)
        if title is not None:
            pl.title(title, fontdict=font_dict)
        if ylim is not None:
            pl.ylim([ylim[0], ylim[1]])
        if xlim is not None:
            pl.xlim([xlim[0], xlim[1]])
        if leg is not None and leg == 'heat':
            self.make_color_bar()
        if ylab is not None:
            pl.ylabel(ylab, fontdict=font_dict)
        pl.xlabel(xlab, fontdict=font_dict)
        if save_SVG or self.save_plot:
            pl.savefig(self.save_path + f"{title}.svg")
            pl.savefig(self.save_path + f"{title}.png")
            print(f"saved to: " + self.save_path + f"{title}.png")
        if show:
            pl.show()

    def set_cog_map_labels(self, plt, subplot = False):
        """Label the x and y axes with ``self.PC_Xlab`` and ``self.PC_Ylab``.

        When ``subplot`` is True, ``plt`` must provide ``set_xlabel``/``set_ylabel``
        (Axes-style). Otherwise it must provide ``xlabel``/``ylabel`` (pyplot-style).
        """
        if subplot:
            label_x, label_y = plt.set_xlabel, plt.set_ylabel
        else:
            label_x, label_y = plt.xlabel, plt.ylabel
        label_x(f"{self.PC_Xlab}")
        label_y(f"{self.PC_Ylab}")
            