from enzyme.src.mouse_task.flow_fields import flow_fields; from scipy.signal import correlate as cross; import numpy as np; from sklearn.decomposition import PCA; import sklearn; from sklearn.linear_model import LinearRegression, LogisticRegression, TweedieRegressor
# from sklearn.manifold import Isomap; from scipy.sparse import csr_matrix, lil_matrix; import umap
 
"""Class for performing the PCA basis construction and projection"""
class trajectory_processing(flow_fields):
    def __init__(self, **params):
        self.__dict__.update(params)
        
    def run_PCA(self, override_indices = False, map_to = None):
        """Fit the PCA basis, project the network output onto it and run the follow-up steps.

        Parameters
        ----------
        override_indices
            When this compares equal to ``False`` (the default) the trial
            indices are re-selected through ``get_indices`` with
            ``cross_validation="training"``; any other value leaves the
            current selection untouched.
        map_to
            Stored unchanged on ``self.cog_map``.
        """
        self.cog_map = map_to

        # Equality test (not truthiness) is kept on purpose: e.g. ``None``
        # does not compare equal to False and therefore skips the re-selection.
        reselect_trials = override_indices == False
        if reselect_trials:
            self.get_indices(planted=False, flatten=False, prep_mus=False, cross_validation="training")

        self.get_PCA_basis()

        # Overwrite the projection stored by get_PCA_basis with the projection
        # of the network output.
        samples = self.output_flat.T
        self.PC_flat = self.proj.transform(samples).T

        self.get_LSTM_weights()
        self.get_theta_error()
        # Optional analyses, currently disabled:
        #   self.proj_PCs_onto_gates(); self.get_corrs()
        #   self.plot_gate_eigen(); self.plot_corrs()
            
    def get_corrs(self):
        """Correlate three input channels with the principal-component projections.

        Attributes written
        ------------------
        inp_cross : ndarray, shape (3, PC_dim, max_traj_steps)
            Filled per component by ``get_PC_cross``.
        auto_corrs : ndarray, shape (PC_dim, max_traj_steps)
            One autocorrelation window per principal component.
        inp_corrs
            Whatever ``self.correlate(X=inputs, Y=PCs)`` returns.
        dim
            Used as the loop variable; holds the last component index afterwards.
        """
        self.inp_cross = np.zeros((3, self.PC_dim, self.max_traj_steps))
        # get_PC_cross writes row ``self.dim`` of this array, and ``self.dim``
        # runs over range(PC_dim); the buffer therefore needs PC_dim rows.
        # (It was previously fixed at 3 rows, which raised IndexError as soon
        # as PC_dim exceeded 3.)
        self.auto_corrs = np.zeros((self.PC_dim, self.max_traj_steps))
        inputs = self.input_flat[[1,3,4], :] #GO/LICK/REW
        PCs = self.proj.transform(self.output_flat.T).T
        self.inp_corrs = self.correlate(X = inputs, Y = PCs)
        # The component index is passed to get_PC_cross through ``self.dim``.
        for self.dim in range(self.PC_dim):
            self.get_PC_cross(inputs, PCs)

    def get_PC_cross(self, inputs, PCs):
        Y = PCs[self.dim, :]
        Y = (Y - Y.mean())/Y.var()
        keep = int(len(Y))
        offset = int(self.max_traj_steps/2)
        for inp_dim in range(inputs.shape[0]):   
            X = inputs[inp_dim,:] 
            X = (X - X.mean())/X.var()
            inp = cross(X, Y)[keep - offset : keep + offset]
            auto = cross(Y, Y)[keep - offset : keep + offset]
            self.inp_cross[inp_dim, self.dim, :] = inp/max(abs(inp))
            self.auto_corrs[self.dim, :] = auto/max(abs(auto))
                
    def run_trajectory(self, plot = True):
        """Compute the mean trajectories and, unless ``plot`` is falsy, draw them.

        Plotting calls ``plot_trajectories`` followed by ``plot_on_PCA_basis``.
        """
        self.get_trajectory_mus()
        if not plot:
            return
        self.plot_trajectories()
        self.plot_on_PCA_basis()


    def get_PCA_basis(self):
        """Flatten the selected trials, fit the PCA basis and fit the linear read-outs.

        Steps, in order:

        1. Every per-trial quantity is restricted to ``self.trial_inds`` and
           passed through ``flatten_trajectory``; results are stored as
           ``*_flat`` attributes.
        2. ``get_data_of`` picks the data named by ``self.basis_on`` among a
           random matrix ("rand"), ``LTM_flat`` ("LTM") and ``output_flat``
           ("output"); a ``PC_dim``-component PCA is fit on its transpose.
           ``proj``, ``basis``, ``data_mean``, ``PC_flat`` and ``PC_diff``
           are stored.
        3. ``get_reg_vec`` is called and regressions from ``self.pred_from.T``
           are fit: ``Q_reg`` (logistic, target ``PSAFE_flat``), ``PGO_reg``,
           ``theta_std_reg``, ``belief_std_reg``, ``bayes_belief_reg`` and
           ``bayes_theta_reg`` (all linear).

        Each quantity is flattened once and each regression is fit once; the
        earlier version repeated the ``C_gate_flat`` flattening and both STD
        regression fits with identical inputs.
        """
        def flat(per_trial):
            # Keep only the currently selected trials, then flatten them.
            return self.flatten_trajectory(per_trial[self.trial_inds])

        self.LTM_flat = flat(self.LTM)
        self.Q_flat = flat(self.Q_values)
        self.PGO_flat = flat(self.step_PGO)
        self.plant_PGO_flat = flat(self.step_plant_PGO)
        self.PSAFE_flat = flat(self.safe_backbone)
        self.C_gate_flat = flat(self.c_gate)
        self.O_gate_flat = flat(self.o_gate)
        self.I_gate_flat = flat(self.i_gate)
        self.F_gate_flat = flat(self.f_gate)
        self.input_flat = flat(self.net_input)
        self.output_flat = flat(self.net_output)
        self.consec_flat = flat(self.consec_stim)

        # Candidate data sets for the basis; the random matrix is always drawn
        # so the global NumPy random state advances regardless of the choice.
        data = [np.random.randn(self.hid_dim, self.hid_dim), self.LTM_flat, self.output_flat]
        names = ["rand", "LTM", "output"]
        x = self.get_data_of(data, names, select = self.basis_on)

        self.proj = PCA(n_components = self.PC_dim)
        self.proj.fit(np.transpose(x))
        self.basis = self.proj.components_
        self.data_mean = x.mean(1)
        print(f"explained variance: {self.proj.explained_variance_ratio_}")
        self.PC_flat = self.proj.transform(x.T).T
        self.PC_diff = np.diff(self.PC_flat, prepend=0)

        # NOTE(review): ``self.pred_from`` is presumably populated by
        # get_reg_vec -- confirm; it is only read after this call.
        self.get_reg_vec()
        features = self.pred_from.T
        self.Q_reg = LogisticRegression().fit(features, self.PSAFE_flat)
        self.PGO_reg = LinearRegression().fit(features, self.PGO_flat)

        self.theta_dist_flat = np.vstack(np.hstack(self.theta_dist_structured[self.trial_inds])).T
        self.belief_dist_flat = np.vstack(np.hstack(self.belief_dist_structured[self.trial_inds])).T
        self.belief_STD = self.dist_to_std(self.belief_dist_flat, np.arange(0,3, 1)[:,None])
        self.theta_STD =  self.dist_to_std(self.theta_dist_flat, np.linspace(0,1, self.bayes_resolution)[:,None])
        self.theta_std_reg = LinearRegression().fit(features, self.theta_STD)
        self.belief_std_reg = LinearRegression().fit(features, self.belief_STD)

        self.flow_belief_flat = flat(self.flow_belief_structured)
        self.flow_theta_flat = flat(self.flow_theta_structured)
        self.bayes_psafe_non_log = 1-np.exp(-self.flow_belief_flat.astype(float))
        self.bayes_belief_reg = LinearRegression().fit(features, self.bayes_psafe_non_log)
        self.bayes_theta_reg = LinearRegression().fit(features, self.flow_theta_flat)

    def get_RMSE(self):
        """Score the fitted read-outs on the "testing" split and print each RMSE.

        ``get_indices`` is called with ``cross_validation="testing"``, the
        STD and Bayesian targets are recomputed, and eight root-mean-square
        errors are stored on the instance (``state_RMSE``, ``theta_RMSE``,
        ``state_STD_RMSE``, ``theta_STD_RMSE``, ``bayes_state_RMSE``,
        ``bayes_theta_RMSE``, ``bayes_state_reg_RMSE``,
        ``bayes_theta_reg_RMSE``) before being printed. The errors are not
        normalised by the target's standard deviation.
        """
        def rmse(target, estimate):
            # Root of the mean squared difference.
            return np.sqrt(np.mean((target - estimate)**2))

        self.get_indices(planted=False, flatten=True, prep_mus=True, cross_validation="testing")

        belief_support = np.arange(0, 3, 1)[:, None]
        self.belief_STD = self.dist_to_std(self.belief_dist_flat, belief_support)
        theta_support = np.linspace(0, 1, self.bayes_resolution)[:, None]
        self.theta_STD = self.dist_to_std(self.theta_dist_flat, theta_support)
        self.flow_theta_flat = self.flatten_trajectory(self.flow_theta_structured[self.trial_inds])
        self.flow_belief_flat = self.flatten_trajectory(self.flow_belief_structured[self.trial_inds])
        self.bayes_psafe_non_log = 1 - np.exp(-self.flow_belief_flat.astype(float))

        features = self.pred_from.T
        # Q_reg is scored through the second column of predict_proba.
        self.state_RMSE = rmse(self.PSAFE_flat, self.Q_reg.predict_proba(features)[:, 1])
        self.theta_RMSE = rmse(self.PGO_flat, self.PGO_reg.predict(features))
        self.state_STD_RMSE = rmse(self.belief_STD, self.belief_std_reg.predict(features))
        self.theta_STD_RMSE = rmse(self.theta_STD, self.theta_std_reg.predict(features))
        self.bayes_state_RMSE = rmse(self.bayes_psafe_non_log, self.Q_reg.predict_proba(features)[:, 1])
        self.bayes_theta_RMSE = rmse(self.flow_theta_flat, self.PGO_reg.predict(features))
        self.bayes_state_reg_RMSE = rmse(self.bayes_psafe_non_log, self.bayes_belief_reg.predict(features))
        self.bayes_theta_reg_RMSE = rmse(self.flow_theta_flat, self.bayes_theta_reg.predict(features))

        report = (
            ("state RMSE ", self.state_RMSE),
            ("theta RMSE ", self.theta_RMSE),
            ("state STD RMSE ", self.state_STD_RMSE),
            ("theta STD RMSE ", self.theta_STD_RMSE),
            ("bayes state RMSE ", self.bayes_state_RMSE),
            ("bayes theta RMSE ", self.bayes_theta_RMSE),
            ("bayes state reg RMSE ", self.bayes_state_reg_RMSE),
            ("bayes theta reg RMSE ", self.bayes_theta_reg_RMSE),
        )
        for label, value in report:
            print(label, value)
        # Disabled: logistic regression to the full PGO distribution
        # PGO_label = np.where(self.PGO_flat[:, None] == self.PGO_range[None,:])[1]
        # self.decoded_dist_reg = LogisticRegression().fit(self.pred_from.T, PGO_label)


    def dist_to_std(self, dist, xax):
        X = dist*xax
        return np.mean(X**2, 0) - np.mean(X, 0)**2       
        
    def get_theta_error(self):
        """Tabulate squared theta errors per split for two ``block`` selections.

        ``get_indices`` is called with ``col='block'`` and ``From == Til`` set
        to ``0`` and then to ``num_trials - 1``. For every entry of
        ``self.split`` the mean squared deviation from ``split_curr`` is
        stored for the Bayesian estimate (``bayes_mse``), the network estimate
        (``net_mse``) and each model in ``MF_theta_flat`` (``MF_mse``).
        ``net_ratio`` / ``bayes_ratio`` divide the first two by the smallest
        ``MF_mse`` of the same cell.

        The tables have ``PGO_N`` rows; this assumes ``self.split`` has at
        most that many entries -- TODO confirm.
        """
        table_shape = (self.PGO_N, 2)
        self.bayes_mse = np.zeros(table_shape)
        self.net_mse = np.zeros(table_shape)
        self.net_ratio = np.zeros(table_shape)
        self.bayes_ratio = np.zeros(table_shape)
        self.MF_mse = np.zeros((self.MF_num, self.PGO_N, 2))

        for trial_i, trial in enumerate((0, self.num_trials - 1)):
            self.get_indices(col='block', From=trial, Til=trial, flatten=True)
            for self.split_i, self.split_curr in enumerate(self.split):
                self.get_split_inds()
                inds = self.full_split_inds_flat.astype(int)
                cell = (self.split_i, trial_i)

                bayes_dev = self.flow_theta_flat[inds] - self.split_curr
                self.bayes_mse[cell] = (bayes_dev**2).mean()
                net_dev = self.net_theta_flat[inds] - self.split_curr
                self.net_mse[cell] = (net_dev**2).mean()
                MF_dev = self.MF_theta_flat[:, inds] - self.split_curr
                self.MF_mse[:, self.split_i, trial_i] = (MF_dev**2).mean(-1)

                # Error of the best-performing MF model for this cell.
                best_MF = self.MF_mse[:, self.split_i, trial_i].min()
                self.net_ratio[cell] = self.net_mse[cell] / best_MF
                self.bayes_ratio[cell] = self.bayes_mse[cell] / best_MF

                        
        
    """ get mean trajectories """

    def get_trajectory_mus(self):
        """Compute the per-step means for every entry of ``self.split``.

        The current position and value are published through
        ``self.split_i`` / ``self.split_curr`` before the helpers run.
        """
        for position, value in enumerate(self.split):
            self.split_i, self.split_curr = position, value
            self.get_split_inds()
            self.get_trajectory_mu_per_step()
            self.postprocess_mus()

    def get_trajectory_mu_per_step(self):
        """Run the step-index selection and averaging for each of ``max_traj_steps`` steps.

        The step being processed is exposed to the helpers as ``self.step``.
        """
        for step in range(self.max_traj_steps):
            self.step = step
            self.get_trajectory_step_inds()
            self.get_net_mus()
   
    def postprocess_mus(self):
        no_prob = self.emp_act_prob[self.split_i, :].sum() == 0
        self.emp_act_prob[self.split_i, :] = 0 if no_prob else self.emp_act_prob[self.split_i, :]/sum(self.emp_act_prob[self.split_i, :])
        max_prob = 1e-3 + np.max(self.emp_act_prob[self.split_i, :])
        # self.act_col[self.split_i, :,:] = self.Cmap.to_rgba(self.emp_act_prob[self.split_i, :]/max_prob)
        self.activities = [self.I_gate_mus, self.F_gate_mus, self.C_gate_mus, self.O_gate_mus, self.LTM_mus, self.output_mus]
        self.PC_traj_diffs[:, self.split_i, :] = np.diff(self.PC_mus[:, self.split_i, :], prepend = 0, axis = -1)
        
    def get_net_mus(self):
        if len(self.step_inds) > 0:
            self.PGO_mus[self.split_i, self.step] = self.PGO_flat[self.step_inds].mean()
            self.GO_mus[self.split_i, self.step] = self.input_flat[1, self.step_inds].mean(-1)

            """ for without full tensors: training data """
            # self.lick_prob_mus[self.split_i, self.step] = self.value_mus[self.split_i, self.step] = self.QDIFF_mus[self.split_i, self.step] =  0
            
            self.lick_prob_mus[self.split_i, self.step] = self.lick_prob_flat[self.step_inds].mean(-1)
            self.value_mus[self.split_i, self.step] = self.value_flat[self.step_inds].mean(-1)
            self.QDIFF_mus[self.split_i, self.step] =  (self.Q_flat[1, self.step_inds] - self.Q_flat[0, self.step_inds]).mean(-1) 
            self.pred_from_mus[:, self.split_i, self.step] = self.pred_from[:, self.step_inds].mean(-1)

            self.output_mus[:, self.split_i, self.step] = self.output_flat[:, self.step_inds].mean(-1)
            self.LTM_mus[:, self.split_i, self.step] = self.LTM_flat[:, self.step_inds].mean(-1)
            self.I_gate_mus[:, self.split_i, self.step] = self.I_gate_flat[:, self.step_inds].mean(-1)
            self.F_gate_mus[:, self.split_i, self.step] = self.F_gate_flat[:, self.step_inds].mean(-1)
            self.O_gate_mus[:, self.split_i, self.step] = self.O_gate_flat[:, self.step_inds].mean(-1)
            self.C_gate_mus[:, self.split_i, self.step] = self.C_gate_flat[:, self.step_inds].mean(-1)
            
            self.PC_mus[:, self.split_i, self.step] = self.PC_flat[:, self.step_inds].mean(-1)    
        
            
            self.flow_theta_mus[self.split_i, self.step]  = self.flow_theta_flat[self.step_inds].mean(-1)
            self.flow_belief_mus[self.split_i, self.step]  = self.flow_belief_flat[self.step_inds].mean(-1)
            self.flow_belief_vars[self.split_i, self.step]  = self.flow_belief_flat[self.step_inds].var(-1)
            self.flow_theta_vars[self.split_i, self.step]  = self.flow_theta_flat[self.step_inds].var(-1)
            self.factorized_theta_mus[self.split_i, self.step]  = self.factorized_theta_flat[self.step_inds].mean(-1)
            self.factorized_belief_mus[self.split_i, self.step]  = self.factorized_belief_flat[self.step_inds].mean(-1)
            self.factorized_belief_vars[self.split_i, self.step]  = self.factorized_belief_flat[self.step_inds].var(-1)
            self.factorized_theta_vars[self.split_i, self.step]  = self.factorized_theta_flat[self.step_inds].var(-1)            
            self.net_theta_mus[self.split_i, self.step]  = self.net_theta_flat[self.step_inds].mean(-1)
            self.net_belief_mus[self.split_i, self.step]  = self.net_belief_flat[self.step_inds].mean(-1)            
            self.net_theta_vars[self.split_i, self.step]  = self.net_theta_flat[self.step_inds].var(-1)
            self.net_belief_vars[self.split_i, self.step]  = self.net_belief_flat[self.step_inds].var(-1)            
        
        
    """ LSTM projections """  
            
    def proj_PCs_onto_gates(self):
        """Multiply the PCA basis with the normalised action and gate weights.

        ``act_proj`` is ``basis @ normed_Wa.T``; ``gate_proj[:, k, :]`` is
        ``basis @ mode_normed_weights[k]`` in a zero-initialised array of
        shape ``(PC_dim, 4, hid_dim)``.
        """
        basis = self.basis
        n_gates = 4
        self.gate_proj = np.zeros((self.PC_dim, n_gates, self.hid_dim))
        self.act_proj = basis @ self.normed_Wa.T
        for slot, weights in enumerate(self.mode_normed_weights):
            self.gate_proj[:, slot] = basis @ weights
        # Alternative kept for reference: self.proj.transform(gate).T
        # self.plot_net_angles()
        
        
    def get_LSTM_weights(self):
        """Extract the action and gate parameters from ``self.net`` and normalise the weights.

        Reads ``'to_action.weight'``, ``'actor.hgates.weight'`` and
        ``'actor.hgates.bias'`` from ``self.net``. When ``self.mode`` is
        ``'RNN'`` the same weight/bias is used for all four gate slots;
        otherwise each is split into four chunks along dimension 0.

        Stores the raw tensors (``Wa``, ``W*``, ``B*``), their converted
        copies from ``self.to(..., 'np')`` (``*_cpu``), the ``norm_W``-normalised
        weights (``normed_W*``) and the lists ``mode_normed_weights`` /
        ``mode_weights`` (order: i, f, c, o).

        ``Wa`` is converted exactly once; the earlier unpacking listed
        ``Wa_cpu`` / ``Wa`` twice, assigning the same conversion two times.
        """
        self.Wa = self.net['to_action.weight']
        W = self.net['actor.hgates.weight']
        B = self.net['actor.hgates.bias']
        is_rnn = self.mode == 'RNN'
        # NOTE(review): chunks are assigned in the order f, i, c, o -- confirm
        # that this matches the parameter layout of ``actor.hgates``.
        self.Wf, self.Wi, self.Wc, self.Wo = (W,W,W,W) if is_rnn else W.chunk(4,0)
        self.Bf, self.Bi, self.Bc, self.Bo = (B,B,B,B) if is_rnn else B.chunk(4,0)
        (self.Wa_cpu, self.Wi_cpu, self.Wo_cpu, self.Wc_cpu, self.Wf_cpu,
         self.Bi_cpu, self.Bo_cpu, self.Bc_cpu, self.Bf_cpu) = self.to(
            [self.Wa, self.Wi, self.Wo, self.Wc, self.Wf,
             self.Bi, self.Bo, self.Bc, self.Bf], 'np')
        self.normed_Wa = self.norm_W(self.Wa_cpu)
        self.normed_Wi = self.norm_W(self.Wi_cpu)
        self.normed_Wf = self.norm_W(self.Wf_cpu)
        self.normed_Wc = self.norm_W(self.Wc_cpu)
        self.normed_Wo = self.norm_W(self.Wo_cpu)
        self.mode_normed_weights = [self.normed_Wi, self.normed_Wf, self.normed_Wc, self.normed_Wo]
        self.mode_weights = [self.Wi_cpu, self.Wf_cpu, self.Wc_cpu, self.Wo_cpu]