import bz2; import pickle; import _pickle as c;   import copy; import pylab as plt; import matplotlib.patches as mpatches; import numpy as np

class data_compare_processor():
    def __init__(self):
        """Construct the processor without initialising any attributes; the plotting methods assign their own state."""

    def plot_root(self, process_mouse = None, plot_each = None, plot_all = None,
                  save_dir = "/home/johns/anaconda3/envs/PFC_env/PFC/Data"):
        """Process every mouse of every name, run the plot hooks, and save the figure.

        For each entry of ``self.names`` the mice from ``self.get_mice()`` are
        visited in turn (``self.n_i``/``self.name`` and ``self.m_i``/``self.mouse``
        are set as loop state): ``process_mouse`` and then ``plot_each`` run for
        every mouse, and ``plot_all`` runs once after the last mouse of a name.
        The current pylab figure is finally saved as ``<save_dir>/<self.title>.svg``.

        Args:
            process_mouse: zero-argument callable run for every mouse, or ``None``.
            plot_each: callable handed to ``self.do`` after every mouse, or ``None``.
            plot_all: callable handed to ``self.do`` once per name, or ``None``.
            save_dir: directory the SVG is written to (defaults to the previously
                hard-coded location).
        """
        self.get_figure()
        # Only the on/off state is stored; the split methods branch on these flags.
        self.plot_each = plot_each is not None
        self.plot_all = plot_all is not None
        for self.n_i, self.name in enumerate(self.names):
            if self.plot_all and isinstance(getattr(self, "data", None), np.ndarray):
                # Callers allocate ``self.data`` once, before this loop, sized with
                # whatever ``self.n_i`` was left over from an earlier call. Re-create
                # it for every name so its per-mouse rows match ``self.Ns[self.n_i]``
                # and no rows from the previous name leak into this name's average.
                self.data = np.zeros((self.Ns[self.n_i],) + self.data.shape[1:], dtype = self.data.dtype)
            for self.m_i, self.mouse in enumerate(self.get_mice()):
                # Call directly: the old ``self.do(process_mouse())`` raised
                # TypeError whenever process_mouse was left at its default None.
                if process_mouse is not None:
                    process_mouse()
                self.do(plot_each)
            self.do(plot_all)
        plt.savefig(f"{save_dir}/{self.title}.svg")

##########################################################################################################################################################################
    """ plot split types """          
##########################################################################################################################################################################

    """ split by PGO """ 
    def PGO_split(self):
        """Draw one curve per PGO value, each on the axis returned by ``self.get_ax()``.

        ``self.plot_each`` draws the row ``self.data[p]``; ``self.plot_all`` draws
        the ``(mu, var)`` pair from ``self.get_mu_var(self.data[:, p])`` as a line
        with a ``mu ± var`` band. Colours come from ``self.PGO_COLS``.
        """
        self.leg = "PGO"
        for idx in range(self.PGO_N):
            axis = self.get_ax()
            if self.plot_each:
                per_mouse = self.data[idx]
                self.general_plot(axis, d = per_mouse, col = self.PGO_COLS[idx])
            if self.plot_all:
                across_mice = self.data[:, idx]
                centre, band = self.get_mu_var(across_mice)
                self.general_plot(axis, d = centre, col = self.PGO_COLS[idx], err = band)
            self.plot_postprocess(axis)
                        
    """ split by N repetitions """ 
    def N_split(self):
        """Draw one curve per repetition index in ``range(self.N)`` on successive axes.

        Every curve uses the current name's colour ``self.NAME_COLS[self.n_i]``.
        ``self.plot_each`` draws ``self.data[n]``; ``self.plot_all`` draws the
        ``(mu, var)`` pair from ``self.get_mu_var(self.data[:, n])`` with a band.
        """
        self.leg = "names"
        for rep in range(self.N):
            axis = self.get_ax()
            if self.plot_each:
                single = self.data[rep]
                self.general_plot(axis, d = single, col = self.NAME_COLS[self.n_i])
            if self.plot_all:
                centre, band = self.get_mu_var(self.data[:, rep])
                self.general_plot(axis, d = centre, col = self.NAME_COLS[self.n_i], err = band)
            self.plot_postprocess(axis)
            
    """ no split """ 
    def no_split(self):
        """Draw ``self.data`` on a single axis in the current name's colour.

        ``self.plot_each`` draws the data as is; ``self.plot_all`` draws the ``mu``
        from ``self.get_mu_var`` with a band of half the returned ``var``.
        """
        self.leg = "names"
        axis = self.get_ax()
        if self.plot_each:
            self.general_plot(axis, d = self.data, col = self.NAME_COLS[self.n_i])
        if self.plot_all:
            centre, spread = self.get_mu_var(self.data)
            half_band = spread / 2
            self.general_plot(axis, d = centre, col = self.NAME_COLS[self.n_i], err = half_band)
        self.plot_postprocess(axis)
        
##########################################################################################################################################################################
    """ plot processing """          
##########################################################################################################################################################################
    """ draw """ 
    def general_plot(self, P, d, col, err = None):
        """Plot the curve ``d`` on axis ``P`` in colour ``col``.

        Without ``err`` the line is drawn with transparency ``self.alph``. With
        ``err`` the line is opaque and the band ``d - err`` .. ``d + err`` is
        shaded at ``self.alph`` over x positions ``0 .. len(d) - 1``.
        """
        if err is None:
            P.plot(d, c = col, alpha = self.alph)
            return
        P.plot(d, c = col, alpha = 1)
        x_positions = np.arange(len(d))
        lower, upper = d - err, d + err
        P.fill_between(x_positions, lower, upper, color = col, alpha = self.alph)

    """ labeling """ 
    def plot_postprocess(self, P):
        """Apply axis limits, labels, title and legend to axis ``P``.

        Limits: when either bound of an axis is set (``self.x0``/``self.x1``,
        ``self.y0``/``self.y1``), both are applied, with an unset lower bound
        defaulting to 0 and an unset upper bound to 1. Labels use fontsize 20.
        Legend: ``self.leg == "PGO"`` lists one patch per PGO value,
        ``self.leg == "names"`` one patch per name; any other value draws no legend.
        """
        def bounds(lo, hi):
            # ``is None`` rather than ``or`` so an explicit 0 bound is kept.
            return [0 if lo is None else lo, 1 if hi is None else hi]

        if self.x0 is not None or self.x1 is not None:
            P.set_xlim(bounds(self.x0, self.x1))
        if self.y0 is not None or self.y1 is not None:
            P.set_ylim(bounds(self.y0, self.y1))
        P.set_xlabel(self.xlab, fontsize = 20)
        P.set_ylabel(self.ylab, fontsize = 20)
        P.set_title(f"{self.name} {self.title}", fontsize = 20)
        if self.leg == "PGO":
            handles = [mpatches.Patch(color = self.PGO_COLS[i], label = f"PGO = {self.PGO_range[i]:.1f}") for i in range(self.PGO_N)]
        elif self.leg == "names":
            handles = [mpatches.Patch(color = self.NAME_COLS[i], label = f"{name}") for i, name in enumerate(self.names)]
        else:
            # Unknown legend mode: previously ``handles`` was left unbound here.
            return
        P.legend(handles = handles)

##########################################################################################################################################################################
    """ Specific plotting """ 
##########################################################################################################################################################################

    """ plot the networks' wait distributions from their last N episodes """
    
    def plot_PGO_dists(self, last_N):
        """Plot each mouse's PGO-split ``split_dist`` built from episodes after ``eps[-1] - last_N``.

        Uses ``all_dists`` to process each mouse and ``PGO_split`` to draw it.
        """
        self.last_N = last_N
        figure_kwargs = dict(W = 15, H = 30, x1 = 20,
                             title = f"wait from last nogo PDF of last {self.last_N} episodes",
                             xlab = "time", rows = 6)
        self.plot_preprocess(**figure_kwargs)
        self.plot_root(process_mouse = self.all_dists, plot_each = self.PGO_split)
    
    def all_dists(self):
        """Select the current mouse's recent block episodes and store its ``split_dist`` in ``self.data``.

        The episode window starts at ``max(0, mouse.eps[-1] - self.last_N)``.
        """
        mouse = self.mouse
        first_episode = max(0, mouse.eps[-1] - self.last_N)
        mouse.get_indices(col = 'block', eps_init = first_episode)
        mouse.preprocess_dists()
        self.data = mouse.split_dist
        
##########################################################################################################################################################################
    """ plot single PGO waiting distribution from last N episodes""" 
    
    def plot_PGO_dist_repetition(self, last_N):
        """For PGO 0.9 and then PGO 0.5, plot the distribution of each of the last ``last_N`` matching episodes.

        ``self.PGO_of_interest`` is set before each figure and is left at 0.5.
        """
        self.last_N = last_N
        for pgo in (.9, .5):
            self.PGO_of_interest = pgo
            self.plot_preprocess(W = 15, H = 30, x1 = 20, y1 = .5, alph = .5,
                                 title = f"last {self.last_N} PGO {pgo} episodes PDFs",
                                 xlab = "time", rows = 6)
            self.plot_root(process_mouse = self.single_dist, plot_each = self.N_split)
    
    def single_dist(self):
        """Store the distributions of the most recent episodes played at ``self.PGO_of_interest``.

        ``self.N`` becomes ``min(self.last_N, number of matching episodes)`` and
        ``self.data`` is an object array whose entry ``n`` is ``split_dist[0]``
        computed over a window around the (n+1)-th most recent matching episode.
        """
        print(f"{self.name} {self.m_i}")
        # isclose instead of ==: PGO values are floats, so exact equality is brittle.
        PGO_inds = np.where(np.isclose(self.mouse.PGOs, self.PGO_of_interest))[0]
        eps = np.unique(self.mouse.eps[PGO_inds])
        self.N = min(self.last_N, len(eps))
        self.data = np.empty(self.N, dtype = object)
        # Walk backwards from the newest episode. The previous ``eps[-n]`` with n
        # starting at 0 picked eps[0] (the oldest) first and never reached eps[-N].
        newest_first = eps[::-1]
        for n in range(self.N):
            ep = newest_first[n]
            self.mouse.get_indices(col = 'block', eps_init = ep - 1, eps_final = ep + 1)
            self.mouse.preprocess_dists()
            self.data[n] = self.mouse.split_dist[0]
            
##########################################################################################################################################################################
    """ plot rew/act/wait for PGO through training """
    
    def plot_through_training(self):
        """Plot how behaviour and recurrent weights evolve over training.

        First, for each behavioural field (reward, action time, wait from last
        nogo, reward rate) the mouse-mean curve is drawn split by PGO, and the
        resulting ``self.data`` buffer is kept in ``self.all_data``. Then, for
        each recurrent-weight field, per-mouse curves and then the mouse-mean
        curve are drawn without a PGO split.
        """
        behaviour_specs = (
            ('episode_rews', "reward probability", 1),
            ('episode_action_times', "action time", 20),
            ('episode_wait_from_last', "wait from last nogo", 15),
            ('episode_rew_rate', 'reward rate', .07),
        )
        # One slot per behavioural field, holding that field's plot_all buffer.
        self.all_data = np.empty(len(behaviour_specs), dtype = object)
        for slot, (self.field, label, top) in enumerate(behaviour_specs):
            self.plot_preprocess(W = 20, H = 5, x1 = 900, y1 = top, alph = .5, title = f"{label} (episode average)",
                                 xlab = "training episode", ylab = "", rows = 1)
            self.data = np.zeros((self.Ns[self.n_i], self.PGO_N, self.x1))
            self.plot_root(process_mouse = self.convolve_training, plot_all = self.PGO_split)
            self.all_data[slot] = self.data

        weight_specs = (
            ("R_DW", "recurrent weight update", 0, .0003),
            ("R_norm", "recurrent weight norm", 8, 10),
        )
        for self.field, label, bottom, top in weight_specs:
            self.plot_preprocess(W = 20, H = 5, x1 = 5000, y0 = bottom, y1 = top, alph = .5, title = f"{label} (episode average)",
                                 xlab = "training episode", ylab = "", rows = 1)
            self.plot_root(process_mouse = self.convolve_training, plot_each = self.no_split)
            self.data = np.zeros((self.Ns[self.n_i], self.x1))
            self.plot_root(process_mouse = self.convolve_training, plot_all = self.no_split)

    def convolve_training(self):
        """Call ``mouse.convolve_data()`` and load the field named ``self.field`` into ``self.data``.

        With ``self.plot_each`` the whole array replaces ``self.data``. With
        ``self.plot_all`` it is written into row ``self.m_i`` of the preallocated
        buffer; 1-D and 2-D fields are cut to the first ``self.x1`` entries along
        their last axis.
        """
        self.mouse.convolve_data()
        D = getattr(self.mouse, self.field)
        if self.plot_each:
            self.data = D
        if self.plot_all:
            if D.ndim in (1, 2):
                # Truncate 1-D series too (previously only 2-D ones were cut), so
                # series longer than x1 fit the (n_mice, x1) buffer.
                self.data[self.m_i] = D[..., :self.x1]
            else:
                self.data[self.m_i] = D
        

##########################################################################################################################################################################
    """ plot PCA aligned on action"""
    
    def plot_PCA_from_action(self):
        """For ``self.plant`` False then True, plot action-aligned ``PC_mus`` split by PGO.

        Each setting is drawn per mouse first, then as the mean over mice from a
        fresh ``(Ns[n_i], PGO_N, x1)`` buffer.
        """
        for planted in [False, True]:
            self.plant = planted
            self.plot_preprocess(W = 20, H = 5, x1 = 20, y0 = -6, y1 = 8, alph = .5,
                                 title = f"PC mus (episode average) (Planted = {self.plant})",
                                 xlab = "time", ylab = "", rows = 1)
            # Per-mouse pass.
            self.plot_root(process_mouse = self.PCA_on_act, plot_each = self.PGO_split)
            # Mouse-average pass.
            self.data = np.zeros((self.Ns[self.n_i], self.PGO_N, self.x1))
            self.plot_root(process_mouse = self.PCA_on_act, plot_all = self.PGO_split)
 
    def PCA_on_act(self):
        """Run PCA on the current mouse, then its action-aligned trajectory, and store ``PC_mus[0]``.

        PCA is run after ``get_indices(planted=False, postprocess=True)``; the
        trajectory after selecting block trials aligned on action with
        ``planted=self.plant``. The ``[0, :, :]`` slice of ``mouse.PC_mus`` goes
        into ``self.data`` (plot_each) or row ``self.m_i`` of it (plot_all).
        """
        mouse = self.mouse
        mouse.get_indices(planted = False, postprocess = True)
        mouse.run_PCA(plot = False)
        trial_filter = dict(col = 'block', From = 0, Til = 50, stim_above = None, stim_below = None, eps_init = 0,
                            planted = self.plant, plant_PGO = None, plant_ID = None, prev_PGO = None,
                            curr_PGO = None, rew = None, align_on = 'action', postprocess = True)
        mouse.get_indices(**trial_filter)
        mouse.run_trajectory(plot = False)
        if self.plot_each:
            self.data = mouse.PC_mus[0, :, :]
        if self.plot_all:
            self.data[self.m_i, :, :] = mouse.PC_mus[0, :, :]
##########################################################################################################################################################################
    """ plot trial from switch split by PGO"""
    
    def plot_wait_from_switch(self):
        """Plot wait-from-last-nogo against trial index after a block switch, split by PGO.

        Drawn per mouse first, then as the mean over mice from a fresh buffer.
        """
        layout = dict(W = 20, H = 5, x1 = 5, y0 = 4, y1 = 13, alph = .5, title = "wait from last nogo",
                      xlab = "trial from block switch", ylab = "", rows = 1)
        self.plot_preprocess(**layout)
        # Per-mouse pass.
        self.plot_root(process_mouse = self.get_wait_from_switch, plot_each = self.PGO_split)
        # Mouse-average pass, written into a (Ns[n_i], PGO_N, x1) buffer.
        self.data = np.zeros((self.Ns[self.n_i], self.PGO_N, self.x1))
        self.plot_root(process_mouse = self.get_wait_from_switch, plot_all = self.PGO_split)

    def get_wait_from_switch(self):
        """Call ``mouse.get_from_switch()`` and load ``mouse.avg_wait_from_last`` into ``self.data``.

        plot_each keeps the whole array; plot_all writes its first ``self.x1``
        columns into row ``self.m_i``.
        """
        mouse = self.mouse
        mouse.get_from_switch()
        waits = mouse.avg_wait_from_last
        if self.plot_each:
            self.data = waits
        if self.plot_all:
            self.data[self.m_i] = waits[:, :self.x1]
##########################################################################################################################################################################
    """ plot trial from switch corr"""
    
    def plot_corr_from_switch(self):
        """Plot the PGO-behaviour correlation against trial index after a block switch.

        One figure is produced for each of the mouse fields ``PGO_corr`` and
        ``last_PGO_corr``: per mouse first, then the mean over mice from a fresh
        ``(Ns[n_i], x1)`` buffer.
        """
        # Each field needs its own title: plot_root saves to "<self.title>.svg"
        # (self.title presumably set from plot_preprocess's ``title``), so one
        # shared title let the 'last_PGO_corr' figure overwrite the 'PGO_corr' one.
        titles = {'PGO_corr': "PGO-behavior correlation",
                  'last_PGO_corr': "last PGO-behavior correlation"}
        for self.field in ['PGO_corr', 'last_PGO_corr']:
            self.plot_preprocess(W = 20, H = 5, x1 = 5, y0 = -.2, y1 = 1, alph = .5, title = titles[self.field], xlab = "trial from block switch", ylab = "", rows = 1)
            self.plot_root(process_mouse = self.get_corr_from_switch, plot_each = self.no_split)
            self.data = np.zeros((self.Ns[self.n_i], self.x1))
            self.plot_root(process_mouse = self.get_corr_from_switch, plot_all = self.no_split)

    def get_corr_from_switch(self):
        """Call ``mouse.get_from_switch()`` and load the mouse field ``self.field`` into ``self.data``.

        plot_each keeps the whole array; plot_all writes its first ``self.x1``
        entries into row ``self.m_i``.
        """
        mouse = self.mouse
        mouse.get_from_switch()
        corr = getattr(mouse, self.field)
        if self.plot_each:
            self.data = corr
        if self.plot_all:
            self.data[self.m_i] = corr[:self.x1]
           
##########################################################################################################################################################################
    """ plot average wait from last nogo for each PGO """
    
    def plot_wait_for_PGO(self):
        """Plot the mean over mice of each mouse's ``split_dist_avg`` against PGO index.

        The buffer is a fresh ``(Ns[n_i], PGO_N)`` array filled by
        ``get_wait_for_PGO`` and drawn with ``no_split``.
        """
        for self.field in ['PGO_corr', 'last_PGO_corr']:
            # ``alph`` (not ``alpha``) is the keyword every other call site passes to
            # plot_preprocess, and the attribute general_plot reads (self.alph).
            self.plot_preprocess(W = 20, H = 5, title = "PGO-wait", xlab = "PGO", ylab = "", rows = 1, alph = .5)
            # NOTE(review): get_wait_for_PGO ignores self.field, so both passes
            # currently draw the same figure — confirm whether this loop is intended.
            self.data = np.zeros((self.Ns[self.n_i], self.PGO_N))
            self.plot_root(process_mouse = self.get_wait_for_PGO, plot_all = self.no_split)

    def get_wait_for_PGO(self):
        """Copy the current mouse's ``split_dist_avg`` into row ``self.m_i`` of ``self.data``."""
        averages = self.mouse.split_dist_avg
        self.data[self.m_i, :] = averages
        
##########################################################################################################################################################################