from enzyme.src.main.run_simulation import run_simulation; from enzyme.src.network.Actor_Critic import Actor_Critic; 
from enzyme.src.mouse_task.mouse_task import mouse_task; from enzyme.src.mouse_task.data_structures import cage_data_struct; import pylab as plt; import numpy as np
import bz2; import pickle; import _pickle as c;   import copy; import pylab as plt; import matplotlib.patches as mpatches
from enzyme.src.mouse_task.data_compare_processor import data_compare_processor

class compare_data_struct(data_compare_processor):
    """Load the RNN and LSTM networks stored in a cage and collect their processed sims.

    For every network stored in ``cage`` under the names in ``self.names``, an
    ``Actor_Critic`` agent is rebuilt from its saved state dict and wrapped in a
    ``run_simulation`` manager. The manager's ``sim`` is preprocessed and stored
    in ``self.mice``, with the matching network name in ``self.each_name``.
    The analysis/plotting entry points (e.g. ``plot_through_training``) are not
    defined here; presumably they come from ``data_compare_processor``.

    NOTE: ``fill_processed_nets`` stores its loop state on the instance
    (``self.n_i``, ``self.name``, ``self.m_i``). ``load_agent``, ``get_ax`` and
    ``get_mice`` read that state, so it is part of this class's contract.
    """

    def __init__(self, cage, testing, episodes = None, trials = None, device = 'cuda:1'):
        """Load and preprocess every RNN and LSTM network stored in ``cage``.

        Args:
            cage: loaded cage object; must provide ``get_num(name)`` and
                ``get_data(name, index[, key])``.
            testing: if True, run each agent through a fresh test simulation
                (``prep_test_mouse``). Otherwise reuse the stored training data
                (``prep_training_mouse``).
            episodes: number of test episodes (only used when ``testing``).
            trials: number of trials per test episode (only used when ``testing``).
            device: device string handed to the agent and the simulation manager.
        """
        self.cage = cage
        self.cuda = device
        self.testing = testing
        self.testing_trials = trials
        self.testing_episodes = episodes
        self.names = ["RNN", "LSTM"]
        self.Ns = [cage.get_num(name) for name in self.names]
        total = sum(self.Ns)
        # Object arrays so the sims and their names can be selected with masks.
        self.mice = np.empty(total, dtype = object)
        self.each_name = np.empty(total, dtype = object)
        self.fill_processed_nets()

    def fill_processed_nets(self, i = 0):
        """Fill ``self.mice`` / ``self.each_name`` with one processed sim per network.

        Also sets ``NAME_COLS`` and copies ``PGO_COLS`` / ``PGO_range`` from the
        first sim.

        Args:
            i: index in ``self.mice`` at which to start writing.

        Raises:
            ValueError: if the cage holds no networks, so there is no sim to
                read the PGO colour/range settings from.
        """
        # Loop state lives on ``self`` on purpose: load_manager/load_agent read
        # self.name and self.m_i, and get_ax/get_mice read self.n_i and self.name.
        for self.n_i, self.name in enumerate(self.names):
            for self.m_i in range(self.cage.get_num(self.name)):
                print(f"{self.name} {self.m_i}")
                self.load_manager()
                self.mice[i] = self.M.sim
                # Keep only the processed sim; drop its raw data and the manager
                # (presumably to bound memory when many networks are loaded).
                del self.M.sim.data
                del self.M
                self.each_name[i] = self.name
                i += 1
        if len(self.mice) == 0:
            raise ValueError("cage contains no RNN or LSTM networks to compare")
        self.NAME_COLS = ["C0", "C1"]
        self.PGO_COLS = self.mice[0].PGO_COLOR
        self.PGO_range = self.mice[0].PGO_range

    def load_manager(self):
        """Build the agent for (``self.name``, ``self.m_i``) and prepare its sim in ``self.M``."""
        self.load_agent()
        if self.testing:
            print(f"{self.name} mouse {self.m_i}")
            self.prep_test_mouse()
        else:
            self.prep_training_mouse()
        self.M.sim.get_PGO_standards()

    def load_agent(self):
        """Rebuild ``self.agent`` for network ``self.name`` #``self.m_i`` from its saved weights.

        All loss/optimiser hyper-parameters are zero. The prep_* methods run the
        agent with ``'training': False``, and its weights are loaded from the cage.
        """
        # The RNN gets 4x the LSTM's hidden units. Presumably this roughly matches
        # parameter counts (an LSTM has four gates); TODO confirm.
        hid_dim = 4 * 48 if self.name == "RNN" else 48
        self.network_params = {'inp_dim': 5, 'hid_dim': hid_dim, 'act_dim': 2, 'device': self.cuda, 'mode': self.name,
            'lesion': [], 'handmade': True, 'subnets': None, 'RAP': 0, 'use_vanilla_torch': False, 'train_recurrent': True}

        self.loss_params = {'discount': 0, 'B_val': 0, 'B_ent': 0, 'decrease_entropy': False, 'context': False}
        self.optim_params = {'lr': 0, 'alpha': 0, 'eps': 0, 'weight_decay': 0, 'momentum': 0, 'centered': False}
        self.agent_params = dict(self.loss_params, **self.optim_params, **self.network_params)
        self.agent = Actor_Critic(**self.agent_params)
        self.agent.load_state_dict(self.cage.get_data(self.name, self.m_i, 'network'))

    def _run_manager(self):
        """Merge agent/task/manager params into ``self.manager_sim_params`` and build ``self.M``."""
        self.manager_params = {'training': False, 'context': None, 'manage_streams': None, 'device': self.cuda}
        # Task and manager keys override same-named agent keys (e.g. 'act_dim',
        # 'context'). Task and manager params must not share keys (TypeError).
        self.manager_sim_params = dict(self.agent_params, **self.mouse_task_params, **self.manager_params)
        self.M = run_simulation(mouse_task, self.manager_sim_params, self.agent, plot_episode = False)

    def prep_training_mouse(self):
        """Build a sim for the current network and preprocess its stored training data.

        The task settings below are placeholders (2 episodes x 2 trials), presumably
        just enough to construct the sim. The analysed data comes from
        ``cage.get_data(self.name, self.m_i)``.
        """
        self.mouse_task_params = {'sim_ID': 3, 'save_path': None, 'act_dim': 0, 'exp_mean': 0, 'exp_max': 0, 'W4L': 1, 'skip': True, 'input': 'one_hot',
            'ITI_mean': 1, 'ITI_PM': 0, 'plant_type': 'random', 'plant_prob': 0, 'ignore_action': False, 'episodes': 2, 'num_trials': 2,
            'exp_min': 0, 'silent_ITI': False, 'time_cost': 0, 'sensory_noise': 0, 'motor_delay': 0,
            'regress_on': [], 'max_traj_steps': 0, 'basis_on': "output", 'PGO_range': [0, .1, .2, .3, .4, .5, .6, .7, .8, .9],
            'neg_rew': 0, 'theta_traj': None, 'start_NOGO': False, 'end_NOGO': False, 'store_tensors': False}
        self._run_manager()
        self.M.sim.preprocess_shallow_data(self.cage.get_data(self.name, self.m_i))
        self.PGO_N = self.M.sim.PGO_N

    def prep_test_mouse(self):
        """Run the current agent on a fresh test task and preprocess the resulting data.

        Uses ``self.testing_episodes`` x ``self.testing_trials``. Shallow and deep
        preprocessing are both applied.
        """
        self.mouse_task_params = {'sim_ID': 3, 'save_path': None, 'act_dim': 2, 'exp_mean': 10, 'exp_max': 50, 'W4L': 35, 'skip': True, 'input': 'one_hot',
            'ITI_mean': 15, 'ITI_PM': 10, 'plant_type': 'random', 'plant_prob': 0, 'ignore_action': False, 'episodes': self.testing_episodes, 'num_trials': self.testing_trials,
            'exp_min': 0, 'silent_ITI': False, 'time_cost': 0, 'sensory_noise': 0, 'motor_delay': 0,
            'regress_on': ["CELL", "FORGET", "INPUT", "OUTPUT"], 'max_traj_steps': 20, 'basis_on': "output", 'PGO_range': [.1, .3, .5, .7, .9],
            'neg_rew': 0, 'theta_traj': None, 'start_NOGO': False, 'end_NOGO': False, 'store_tensors': False}
        self._run_manager()
        self.M.data.postprocess_specials()
        self.M.sim.preprocess_shallow_data(self.M.data)
        self.M.sim.preprocess_deep_data(self.M)
        self.PGO_N = self.M.sim.PGO_N

    # ------------------------------------------------------------------------------------------
    # helpers
    # ------------------------------------------------------------------------------------------
    def get_mice(self):
        """Return the processed sims whose network type equals ``self.name``."""
        return self.mice[self.each_name == self.name]

    def do(self, do = None):
        """Call ``do()`` if a callable is given; no-op for ``None``."""
        if do is not None:
            do()

    def get_figure(self):
        """Create a ``self.rows`` x 2 subplot grid and store the axes in ``self.ax``.

        Each column belongs to one network type (``get_ax`` indexes columns by
        ``self.n_i``). Size comes from ``self.W`` / ``self.H``.
        """
        _, self.ax = plt.subplots(self.rows, 2, figsize = (self.W, self.H), tight_layout = True)

    def get_ax(self):
        """Return the axis for network type ``self.n_i`` (and mouse ``self.m_i``).

        Supported layouts:
          * ``rows == 1``: one panel per network type; ``self.ax`` is 1-D.
          * ``rows == Ns[n_i]``: one row per mouse; ``self.ax`` is 2-D.

        ``rows == 1`` is checked first because ``plt.subplots(1, 2)`` returns a
        1-D axes array. A network type with a single mouse must therefore be
        indexed as ``ax[n_i]``, even in the per-mouse layout.

        Raises:
            ValueError: if ``self.rows`` matches neither layout.
        """
        if self.rows == 1:                                                           # plot per network type
            return self.ax[self.n_i]
        if self.rows == self.Ns[self.n_i]:                                           # plot per mouse
            return self.ax[self.m_i, self.n_i]
        raise ValueError(f"rows={self.rows} matches neither a single row nor the "
                         f"{self.Ns[self.n_i]} {self.names[self.n_i]} networks")

    def get_mu_var(self, D):
        """Return ``(mean, std)`` of ``D`` along axis 0.

        NOTE: despite the name, the second value is the standard deviation, not
        the variance.
        """
        return D.mean(0), D.std(0)

    def plot_preprocess(self, W = 15, H = 10, x0 = None, x1 = None, y0 = None, y1 = None, alph = 1, title = "", xlab = "", ylab = "", rows = None):
        """Store figure size, axis limits, alpha, labels and row count for the next plot."""
        self.W = W
        self.H = H
        self.x0 = x0
        self.x1 = x1
        self.y0 = y0
        self.y1 = y1
        self.alph = alph
        self.title = title
        self.xlab = xlab
        self.ylab = ylab
        self.rows = rows
        

# ##########################################################################################################################################################################

if __name__ == '__main__':
    import os

    # Default is the saved training cage on the original machine. Set
    # TRAINING_CAGE_PATH to analyse a cage stored elsewhere without editing this file.
    path = os.environ.get('TRAINING_CAGE_PATH', '/home/johns/anaconda3/envs/PFC_env/PFC/Data/training_cage')

    print("loading mice")
    cage = cage_data_struct(location = path, load = True)

    print("anesthisizing mice")
    C = compare_data_struct(cage, testing = False)

    print("analyzing mice")
    C.plot_through_training()

    # Alternative analyses (uncomment to run).
    # Training-data analyses:
    # C.plot_PGO_dists(last_N = 50)
    # C.plot_PGO_dist_repetition(last_N = 50)
    #
    # Test-simulation analyses (pick one construction):
    # C = compare_data_struct(cage, testing = True, episodes = 50, trials = 10)
    # C = compare_data_struct(cage, testing = True, episodes = 500, trials = 20)
    # C = compare_data_struct(cage, testing = True, episodes = 50, trials = 20)
    # C.plot_PGO_dists(last_N = 500)
    # C.plot_PGO_dist_repetition(last_N = 50)
    # C.plot_PCA_from_action()
    # C.plot_wait_from_switch()
    # C.plot_corr_from_switch()
    # C.plot_wait_for_PGO()