# Load and plot the results for 5 different values of alpha (labelled beta in the figure). All runs were generated by the same code and differ only in the parameter that rescales the state-entropy term.

import numpy as np
import random
import math
import torch
import torch.optim as optim
import torch.nn as nn
import matplotlib.pyplot as plt
import dill
import seaborn as sns

#%% OPENING THE ALPHA = 0 CASE

def _load_run(name_file):
    """Load the training and activity results saved for one run.

    Reads ``<name_file>_training.pkl`` and then ``<name_file>_activity.pkl``
    with dill and returns the two unpickled objects as ``(training, activity)``.

    dill, like pickle, can execute arbitrary code while loading: only use it
    on result files you produced yourself.
    """
    with open(name_file + '_training.pkl', 'rb') as file:
        training = dill.load(file)
    with open(name_file + '_activity.pkl', 'rb') as file:
        activity = dill.load(file)
    return training, activity


training, activity = _load_run('Figure5_alpha00')
_, _, STD_00, STD_single_00, Var_00, Var_single_00, Delta_var_00, T_end_00, ED_00, ED_thresh_00, ED_state_00, _, _, _ = training
Activity_00, _, Energy_00, _ = activity

#%% OPENING THE ALPHA = 0.2 CASE
training, activity = _load_run('Figure5_alpha02')
_, _, STD_02, STD_single_02, Var_02, Var_single_02, Delta_var_02, T_end_02, ED_02, ED_thresh_02, ED_state_02, _, _, _ = training
Activity_02, _, Energy_02, _ = activity

#%% OPENING THE ALPHA = 0.5 CASE
training, activity = _load_run('Figure5_alpha05')
_, _, STD_05, STD_single_05, Var_05, Var_single_05, Delta_var_05, T_end_05, ED_05, ED_thresh_05, ED_state_05, _, _, _ = training
Activity_05, _, Energy_05, _ = activity

#%% OPENING THE ALPHA = 0.8 CASE
training, activity = _load_run('Figure5_alpha08')
_, _, STD_08, STD_single_08, Var_08, Var_single_08, Delta_var_08, T_end_08, ED_08, ED_thresh_08, ED_state_08, _, _, _ = training
Activity_08, _, Energy_08, _ = activity

#%% OPENING THE ALPHA = 1.5 CASE
# Besides its own results, this run also provides `params`, `MSE`, the `rnn`
# and `fnn` objects and the "*_free" quantities used by the plots below.
training, activity = _load_run('Figure5_alpha15')
params, MSE, STD_15, STD_single_15, Var_15, Var_single_15, Delta_var_15, T_end_15, ED_15, ED_thresh_15, ED_state_15, ED_state_free, rnn, fnn = training
Activity_15, Activity_free, Energy_15, Energy_free = activity


#%% PLOT
# Colour palettes, written as bare hex codes and then prefixed with '#'.
pal_pre = ['D45187', 'B4326C', '950852', '76003A', '580023', '6c757d']
pal_pre = [f"#{c}" for c in pal_pre]

pal_grey = ['f8f9fa', 'e9ecef', 'dee2e6', 'ced4da', 'adb5bd', '6c757d', '495057', '343a40']
pal_grey = [f"#{c}" for c in pal_grey]

# set default style
# sns.reset_defaults() restores matplotlib's default rcParams, and
# sns.set_theme() then sets the font family itself. Any rcParams assigned
# before these two calls would be discarded, so the TeX/STIX settings are
# applied only after the theme (see below).
sns.reset_defaults() # useful when adjusting style a lot
sns.set_theme(context="paper", style="ticks",
              # `pal_blue` was never defined (NameError); use the figure palette.
              palette=pal_pre,
              rc={
              "pdf.fonttype": 42,  # embed font in output
              "svg.fonttype": "none",  # embed font in output
              "figure.facecolor": "white",
              "figure.dpi": 100,
              "axes.facecolor": "None",
              "axes.spines.left": True,
              "axes.spines.bottom": True,
              "axes.spines.right": False,
              "axes.spines.top": False,
          },
          )

# Text rendering. text.usetex requires a working LaTeX installation.
plt.rcParams['text.usetex'] = True
plt.rcParams['mathtext.fontset'] = 'stix'
plt.rcParams['font.family'] = 'STIXGeneral'

# Presumably kept to preview the palette when run as an interactive cell;
# in a plain script the returned palette is simply discarded.
sns.color_palette(pal_pre)
#%% OCCUPANCY
N_bins = 50
min_bins = -1
max_bins = 1
delta_bins = (max_bins - min_bins)/N_bins


def _occupancy(activity, n_neurons, n_t, n_bins=N_bins, lo=min_bins, hi=max_bins):
    """Count, per time step, how many activity samples fall in each bin.

    Parameters
    ----------
    activity : array-like or torch.Tensor
        Indexed as ``activity[agent, neuron, time]``. Only the first
        ``n_neurons`` neurons and ``n_t`` time steps are used.
    n_neurons, n_t : int
        Number of neurons / time steps to include.
    n_bins, lo, hi :
        ``n_bins`` uniform bins covering ``[lo, hi)``.

    Returns
    -------
    np.ndarray, shape (n_bins, n_t)
        ``occupancy[b, t]`` is the number of (agent, neuron) samples at time
        ``t`` with ``edges[b] <= x < edges[b + 1]``. Values outside
        ``[lo, hi)`` and NaNs are not counted.
    """
    if torch.is_tensor(activity):
        activity = activity.detach().cpu().numpy()
    act = np.asarray(activity)[:, :n_neurons, :n_t]
    # Same edge values as `min_bins + bb * delta_bins` in the original loop.
    edges = lo + np.arange(n_bins + 1) * ((hi - lo) / n_bins)
    # Half-open bins: a sample sitting exactly on an interior edge is counted
    # in the upper bin instead of being dropped (the old strict `<` test on
    # both sides discarded it, e.g. every x == 0.0).
    bins = np.searchsorted(edges, act, side='right') - 1
    valid = (bins >= 0) & (bins < n_bins)
    times = np.broadcast_to(np.arange(act.shape[2]), act.shape)
    occupancy = np.zeros((n_bins, act.shape[2]))
    np.add.at(occupancy, (bins[valid], times[valid]), 1)
    return occupancy


Occupancy_free = _occupancy(Activity_free, rnn.N, rnn.t)
Occupancy_00 = _occupancy(Activity_00, rnn.N, rnn.t)
Occupancy_02 = _occupancy(Activity_02, rnn.N, rnn.t)
Occupancy_05 = _occupancy(Activity_05, rnn.N, rnn.t)
Occupancy_08 = _occupancy(Activity_08, rnn.N, rnn.t)
Occupancy_15 = _occupancy(Activity_15, rnn.N, rnn.t)
def _mean_and_sem(delta_var):
    """Return (mean, standard error) of the axis-1 means of ``delta_var``.

    ``delta_var`` is first averaged over axis 1. The resulting values are then
    summarised over axis 0: their mean, and their std divided by
    sqrt(delta_var.shape[0]). Axis 0 presumably indexes independent runs —
    TODO confirm.
    """
    row_means = np.mean(delta_var, axis=1)
    centre = np.mean(row_means, axis=0)
    spread = np.std(row_means, axis=0) / np.sqrt(delta_var.shape[0])
    return centre, spread


Delta_00, Delta_00_err = _mean_and_sem(Delta_var_00)
Delta_02, Delta_02_err = _mean_and_sem(Delta_var_02)
Delta_05, Delta_05_err = _mean_and_sem(Delta_var_05)
Delta_08, Delta_08_err = _mean_and_sem(Delta_var_08)
Delta_15, Delta_15_err = _mean_and_sem(Delta_var_15)


#%% PLOTTING IN A SINGLE PLOT

# Index of the agent whose individual trajectories are shown (any valid index).
traj = 9

# One row per run: the "State*" panel spans two columns, followed by the
# "E_*" and "O_*" panels; the last column holds the summary bar charts.
mosaic_layout = [
    ["Statefree", "Statefree", "E_Statefree", "O_Statefree", "T_end"],
    ["State00", "State00", "E_State00", "O_State00", "STD_delta"],
    ["State08", "State08", "E_State08", "O_State08", "STD_single"],
    ["State15", "State15", "E_State15", "O_State15", "ED_action"],
]
grid_options = {
    "hspace": 0.6,
    "wspace": 0.8,
    "width_ratios": [0.6, 0.4, 1, 0.6, 0.8],
}
fig, axes = plt.subplot_mosaic(mosaic_layout, gridspec_kw=grid_options)
                    

def _plot_state_row(suffix, activity, energy, occupancy, color, n_traces=20):
    """Fill the three single-run panels of one mosaic row.

    The panels are ``axes["State" + suffix]``, ``axes["E_State" + suffix]``
    and ``axes["O_State" + suffix]``. All rows share identical axis limits
    and ticks.

    Parameters
    ----------
    suffix : str
        Row key suffix: "free", "00", "08" or "15".
    activity : array-like
        Indexed ``[agent, neuron, time]``; the first ``n_traces`` neurons of
        agent ``traj`` are drawn.
    energy : array-like
        Indexed ``[agent, 0, time]``; agent ``traj`` is drawn.
    occupancy : np.ndarray
        Bin counts of shape (N_bins, time); averaged over time and divided
        by ``rnn.N`` for the horizontal bar chart.
    color : str
        Colour shared by all three panels.
    """
    fs = 12
    t_axis = torch.arange(0, rnn.t)

    # Activity traces.
    ax_state = axes["State" + suffix]
    ax_state.plot(t_axis, activity[traj, :n_traces, :].T, alpha=0.9, color=color)
    ax_state.set_xlabel('t', fontsize=fs)
    ax_state.set_xlim(0, 600)
    ax_state.set_ylabel('x(t)', fontsize=fs)
    ax_state.set_xticks([0, 600], ['0', '600'], fontsize=fs)
    ax_state.set_yticks([-1, 0, 1], ['-1', '0', '1'], fontsize=fs)

    # Occupancy histogram; bin indices 0/25/50 correspond to x = -1/0/1.
    ax_occ = axes["O_State" + suffix]
    ax_occ.barh(np.arange(N_bins), np.mean(occupancy, axis=1) / rnn.N,
                height=1, edgecolor='none', color=color)
    ax_occ.set_yticks([0, 25, 50], ['-1', '0', '1'], fontsize=fs)
    ax_occ.set_xticks([], [])
    ax_occ.set_xlabel('pdf', fontsize=fs)
    ax_occ.set_ylabel('x', fontsize=fs)

    # Energy trace with rnn.x_th_plus drawn as a dashed red reference line.
    ax_energy = axes["E_State" + suffix]
    ax_energy.plot(t_axis, energy[traj, 0, :], alpha=0.9, color=color)
    ax_energy.plot(t_axis, rnn.x_th_plus * torch.ones(rnn.t),
                   linestyle='--', linewidth=1.5, color='red')
    ax_energy.set_xlabel('t', fontsize=fs)
    ax_energy.set_ylabel('E(t)', fontsize=fs)
    # Previously missing for the "08" row only, so its energy panel was not
    # comparable with the others.
    ax_energy.set_xlim(0, 600)
    ax_energy.set_ylim(0.06, 0.14)
    ax_energy.set_xticks([0, 600], ['0', '600'], fontsize=fs)
    ax_energy.set_yticks([0.06, 0.11, 0.14], ['0.06', '0.11', '0.14'], fontsize=fs)


# Row colours match the bar-chart colours used per beta
# (beta = 0 -> pal_pre[0], 0.8 -> pal_pre[3], 1.5 -> pal_pre[4]).
# The free run uses the grey pal_pre[-1].
_plot_state_row("free", Activity_free, Energy_free, Occupancy_free, pal_pre[-1])
_plot_state_row("00", Activity_00, Energy_00, Occupancy_00, pal_pre[0])
_plot_state_row("08", Activity_08, Energy_08, Occupancy_08, pal_pre[3])
_plot_state_row("15", Activity_15, Energy_15, Occupancy_15, pal_pre[4])

# Shared settings for the four bar charts: one bar per beta value.
colors = pal_pre[:5]
x_pos = list(range(5))
betas = ['0.0', '0.2', '0.5', '0.8', '1.5']

# NOTE(review): this panel plots Var_single_* under a <sigma> label, while
# the loaded STD_single_* arrays are never used — confirm which is intended.
var_single_runs = [Var_single_00, Var_single_02, Var_single_05, Var_single_08, Var_single_15]
data = [np.mean(run) for run in var_single_runs]
yerr = [np.std(np.mean(run, axis=1)) / np.sqrt(params['Naverage'])
        for run in var_single_runs]

ax_single = axes["STD_single"]
ax_single.bar(x_pos, data, yerr=yerr, color=colors)
ax_single.set_ylabel(r'$\langle \sigma \rangle$', fontsize=12)
ax_single.set_yticks([0, 0.2, 0.4], ['0', '0.2', '0.4'], fontsize=12)
ax_single.set_xticks(x_pos, betas, fontsize=12)
ax_single.set_xlabel(r'$\beta$', fontsize=12)

# Mean Delta per beta with the error bars computed alongside it above.
data = [Delta_00, Delta_02, Delta_05, Delta_08, Delta_15]
yerr = [Delta_00_err, Delta_02_err, Delta_05_err, Delta_08_err, Delta_15_err]

ax_delta = axes["STD_delta"]
ax_delta.bar(x_pos, data, yerr=yerr, color=colors)
ax_delta.set_ylabel(r'$\langle \sigma_{\Delta x} \rangle$', fontsize=12)
ax_delta.set_yticks([0, 0.02, 0.04], ['0', '0.02', '0.04'], fontsize=12)
ax_delta.set_xticks(x_pos, betas, fontsize=12)
ax_delta.set_xlabel(r'$\beta$', fontsize=12)

# ED_a per beta: mean of each ED array, error bar = std / sqrt(Naverage).
ed_runs = [ED_00, ED_02, ED_05, ED_08, ED_15]
data = [np.mean(ed) for ed in ed_runs]
yerr = [np.std(ed) / np.sqrt(params['Naverage']) for ed in ed_runs]

axes["ED_action"].bar(x_pos, data, yerr=yerr, color=colors)
axes["ED_action"].set_ylabel(r'$ED_a$', fontsize=12)
axes["ED_action"].set_yticks([0, 5, 8], ['0', '5', '8'], fontsize=12)
axes["ED_action"].set_xticks(x_pos, betas, fontsize=12)
axes["ED_action"].set_xlabel(r'$\beta$', fontsize=12)

# t_end per beta: last column of each T_end array, averaged over axis 0;
# error bar = std over axis 0 / sqrt(Naverage).
t_end_last = [t_end[:, -1] for t_end in (T_end_00, T_end_02, T_end_05, T_end_08, T_end_15)]
data = [np.mean(t, axis=0) for t in t_end_last]
yerr = [np.std(t, axis=0) / np.sqrt(params['Naverage']) for t in t_end_last]

axes["T_end"].bar(x_pos, data, yerr=yerr, color=colors)
axes["T_end"].set_ylabel(r'$t_{end}$', fontsize=12)
axes["T_end"].set_yticks([0, 400, 600], ['0', '400', '600'], fontsize=12)
# Single tick call: the earlier beta-label call was immediately overridden.
axes["T_end"].set_xticks(x_pos, betas, fontsize=12)
axes["T_end"].set_xlabel(r'$\beta$', fontsize=12)

plt.show()
