from run_simulation import run_simulation; import torch; from Actor_Critic import Actor_Critic;  from mouse_task import mouse_task; import bz2; import pickle; import _pickle as c; from data_structures import cage_data_struct

def manage_data_storage(manager = None, operation = None, path = None, name = None):
    """Save or load the dictionary form of ``manager.data`` as a pickle file.

    Parameters
    ----------
    manager : object
        Must expose ``manager.data`` with ``to_dictionary()``, ``from_dictionary(d)``
        and ``remove_tensors()``. For 'load_pkbz' it must also expose
        ``manager.sim.preprocess_data(manager)``.
    operation : str
        'save_behavior' : dump ``manager.data.to_dictionary()`` to ``<name>.pkbz2`` (bz2, level 1).
        'save_matlab'   : call ``manager.data.remove_tensors()`` (this mutates ``manager.data``),
                          then dump the dictionary to an uncompressed ``<name>.pickle``.
        'load_pickle'   : read ``<name>.pickle`` and pass it to ``manager.data.from_dictionary``.
        'load_pkbz'     : read ``<name>.pkbz2``, pass it to ``manager.data.from_dictionary`` and
                          then run ``manager.sim.preprocess_data(manager)``.
        Any other value (including the default ``None``) does nothing.
    path : str or os.PathLike
        Directory of the file. A ``str`` is used as a raw prefix (``path + name``), so it
        has to end with a separator; a path object is joined with ``name`` properly.
    name : str
        File name without extension.

    Returns
    -------
    None
    """
    import os

    def storage_file(extension):
        # Plain strings keep the historical prefix concatenation; path objects
        # (which do not support ``+``) are joined as directory + file name.
        filename = name + extension
        if isinstance(path, str):
            return path + filename
        return os.path.join(os.fspath(path), filename)

    if operation == 'save_behavior':
        data = manager.data.to_dictionary()
        with bz2.BZ2File(storage_file('.pkbz2'), 'w', compresslevel = 1) as f:
            c.dump(data, f)

    elif operation == 'save_matlab':
        manager.data.remove_tensors()
        data = manager.data.to_dictionary()
        with open(storage_file('.pickle'), 'wb') as f:
            pickle.dump(data, f)

    # NOTE: unpickling can execute arbitrary code; only load files you created yourself.
    elif operation == 'load_pickle':
        with open(storage_file('.pickle'), 'rb') as f:
            data = c.load(f)
        manager.data.from_dictionary(data)
#        manager.sim.preprocess_data(manager)

    elif operation == 'load_pkbz':
        # Context manager so the compressed file handle is closed after reading.
        with bz2.BZ2File(storage_file('.pkbz2'), 'rb') as f:
            data = c.load(f)
        manager.data.from_dictionary(data)
        manager.sim.preprocess_data(manager)
        
        
# ---- network / device settings ----
cuda = 'cuda:0'
mode = 'RNN'            # choices: 'LSTM', 'RNN'
prefix = 'RNN'          # stem shared by the weights file name and the data file name
handmade = False
subnets = None          # choices: None, 'joint', 'isolated'
# choices: None, 'INPUT', 'OUTPUT', 'FORGET', 'LTM', 'STM', 'LTM_bound'
# (author's note: training with bound, testing without bound for OFC)
lesion = []
regress_on = ["CELL"]   # choices: "CELL", "FORGET", "INPUT", "OUTPUT", "STM", "LTM"
# The 'RNN' mode gets four times the hidden size used otherwise.
hid_dim = 48
if mode == 'RNN':
    hid_dim = 4 * 48
max_traj_steps = 24

# ---- run mode and optimisation settings ----
train = False           # False: skip training and only load the saved weights
inp = 'one_hot'         # choices: 'raw', 'one_hot', 'diff'
discount = 0.995
lr = 0.0001             # author's note: .00005 for RNN, .001 for subnets

ITI_mean = 15
ITI_PM = 10

act_dim = 2
momentum = 0
ent_bonus = 0
value_beta = 0.1

# Input width depends on the chosen input encoding.
if inp == 'one_hot':
    inp_dim = 5
elif inp in ('raw', 'diff'):
    inp_dim = 3

# ---- file names and compute device ----
net_filename = f'{prefix}_net.pth'
data_filename = prefix
if torch.cuda.is_available():
    device = cuda
else:
    device = 'cpu'
print(f"device = {device}")
from enzyme import PRJ_ROOT
save_path = PRJ_ROOT / 'Data'

# Keyword groups that are merged below and handed to the agent constructor.
loss_params = {
    'discount': discount,
    'B_val': value_beta,
    'B_ent': ent_bonus,
    'decrease_entropy': False,
}
network_params = {
    'inp_dim': inp_dim,
    'hid_dim': hid_dim,
    'act_dim': act_dim,
    'device': device,
    'mode': mode,
    'lesion': lesion,
    'handmade': handmade,
    'subnets': subnets,
    'RAP': 0,
    'use_vanilla_torch': True,
    'train_recurrent': True,
}
optim_params = {
    'lr': lr,
    'alpha': .99,
    'eps': 1e-5,
    'weight_decay': .0,
    'momentum': momentum,
    'centered': False,
}
# Merge order: loss, then optimiser, then network settings (the groups share no keys).
agent_params = {**loss_params, **optim_params, **network_params}
agent = Actor_Critic(**agent_params)

# Baseline task settings; the training / testing sections overwrite some of these later.
mouse_task_params = {
    'sim_ID': 3,
    'save_path': save_path,
    'act_dim': act_dim,
    'exp_mean': 10,
    'exp_max': 50,
    'W4L': 35,
    'ITI_mean': ITI_mean,
    'ITI_PM': ITI_PM,
    'store_tensors': False,
    'ignore_action': False,
    'plant_type': 'random',
    'end_NOGO': False,
    'start_NOGO': False,
    'regress_on': regress_on,
    'max_traj_steps': max_traj_steps,
    'basis_on': 'output',
    'plant_prob': 0,
    'neg_rew': 0,
    'theta_traj': None,
}
manager_params = {
    'training': True,
    'input': inp,
    'device': device,
}
 
if train:
    # Train the agent: extend the task settings with the training schedule.
    mouse_task_params.update({
        'PGO_range': [.1, .3, .5, .7, .9],
        'episodes': 5000,
        'num_trials': 500,
        'exp_min': 4,
    })
    # Later groups override earlier ones on shared keys (e.g. 'act_dim', 'device').
    training_sim_params = {**agent_params, **mouse_task_params, **manager_params}
    training = run_simulation(mouse_task, training_sim_params, agent, plot_episode = True)
    training.sim.plot_all_episodes(training)

    # Save the trained weights next to the data.
    torch.save(agent.state_dict(), save_path / net_filename)
    # manage_data_storage(manager = training, operation = 'save_behavior', path = save_path, name = data_filename)

"""Load agent"""                                                             
# Restore the agent weights from the file the training branch writes.
# ``save_path`` is built with the ``/`` operator, so the file name is joined the same
# way here; ``save_path + net_filename`` raises TypeError for a pathlib-style path.
load_path = save_path / net_filename
# map_location moves the stored tensors onto whichever device was selected above.
agent.load_state_dict(torch.load(load_path, map_location=torch.device(device)))

"""Set up testing parameters"""
# Overwrite the task settings for the test run; tensors are stored this time.
mouse_task_params.update({
    'PGO_range': [0, .1, .2, .3, .4, .5, .6, .7, .8, .9],
    'plant_type': 'random',
    'ignore_action': False,
    'end_NOGO': True,
    'plant_prob': .1,
    'episodes': 2000,
    'num_trials': 20,
    'exp_min': 1,
    'store_tensors': True,
})
manager_params['training'] = False

# Run testing. Later groups override earlier ones on shared keys.
testing_sim_params = {**agent_params, **mouse_task_params, **manager_params}
testing = run_simulation(mouse_task, testing_sim_params, agent, plot_episode = True)
testing.sim.plot_all_episodes(testing)

"""PCA analysis"""
testing.sim.preprocess_data(manager = testing)
testing.sim.run_PCA()

# Selection used by both the trajectory analysis and the dynamics analysis below.
planted_block_selection = dict(
    col = 'block',
    From = 0,
    Til = 50,
    stim_above = None,
    stim_below = None,
    eps_init = 0,
    planted = True,
    plant_PGO = None,
    plant_ID = None,
    prev_PGO = None,
    curr_PGO = None,
    rew = None,
    align_on = 'action',
    flatten = True,
)

# ---- trajectory analysis ----
testing.sim.get_indices(**planted_block_selection)
testing.sim.run_trajectory(plot = True)

# ---- regression analysis ----
testing.sim.run_REG()
testing.sim.plot_mem_PGO_corr()

mem = 20
# Re-select non-planted data starting from ``testing.sim.held_out``, aligned on onset.
testing.sim.get_indices(
    planted = False,
    eps_init = testing.sim.held_out,
    align_on = 'onset',
    flatten = True,
)
testing.sim.get_reconstruction(mem)
testing.sim.run_trajectory(plot = False)
testing.sim.plot_bayes_vs_recon(mem)

# ---- plotting for NOGO bump ----
testing.sim.plot_bump_reconstruction(end = 7, mem = mem, plant_ID = None)
testing.sim.plot_update_to_NOGO(mem = mem)

# ---- dynamics analysis ----
testing.sim.get_indices(**planted_block_selection)
testing.sim.get_dynamics(
    with_baseline = True,
    with_STM = False,
    with_LTM = False,
    effect_of = "FORGET",
)

# %matplotlib qt # %matplotlib inline