#%% IMPORTING THE LIBRARIES

import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import dill
from scipy.stats import ttest_ind


#%% PLOT FEATURES

# Grey ramp, light to dark; installed below as the default seaborn palette.
pal_grey = ['dee2e6', 'ced4da', 'adb5bd', '6c757d', '495057', '343a40']
pal_grey = [f"#{c}" for c in pal_grey]
# Accent colours (RGBA hex) used for the MOP and the greedy agent.
pal_MOP = '#c55986ff'
pal_greedy = '#354a78ff'

# NOTE(review): sns.reset_defaults() below restores matplotlib's default
# rcParams and sns.set_theme() sets the font family itself, so these three
# settings are probably not in effect when the figure is drawn. Confirm
# whether LaTeX/STIX rendering is wanted before moving them after set_theme.
plt.rcParams['text.usetex'] = True
plt.rcParams['mathtext.fontset'] = 'stix'
plt.rcParams['font.family'] = 'STIXGeneral'

# set default style
sns.reset_defaults() # useful when adjusting style a lot
sns.set_theme(context="paper", style="ticks",
              # palette="Set2",
              palette=pal_grey,
              rc={
              "pdf.fonttype": 42,  # embed font in output
              "svg.fonttype": "none",  # embed font in output
              "figure.facecolor": "white",
              "figure.dpi": 100,
              "axes.facecolor": "None",
              "axes.spines.left": True,
              "axes.spines.bottom": True,
              "axes.spines.right": False,
              "axes.spines.top": False,
          },
          )

# Palette preview (rendered only when this cell is run interactively).
# `pal_grey` is the palette list defined in this file; referencing an
# undefined palette name here would abort the script with a NameError.
sns.color_palette(pal_grey)

#%%LOAD DATA

name_file = 'Figure2'

# NOTE(review): dill.load can execute arbitrary code embedded in the file;
# only load pickles that were produced by a trusted run of this project.

# Training results: a 12-element sequence, unpacked in storage order.
training_path = f"{name_file}_training.pkl"
with open(training_path, 'rb') as file:
    data = dill.load(file)

(
    params, MSE, STD, STD_single, T_end,
    MSE_greedy, STD_greedy, STD_single_greedy, T_end_greedy,
    rnn, MOP, greedy,
) = data

# Recorded activity: a 14-element sequence, unpacked in storage order.
activity_path = f"{name_file}_activity.pkl"
with open(activity_path, 'rb') as file:
    data_activity = dill.load(file)

(
    ED_MOP_far, ED_MOP_thresh, ED_greedy_far, ED_greedy_thresh,
    Activity_MOP, Energy_MOP, H_mop,
    Activity_greedy, Energy_greedy, H_greedy,
    Activity_free, Energy_free,
    Activity_train, H_train,
) = data_activity

# Energy of the training trajectories: norm of (activity + 1) taken along
# axis 1, divided by the length of that same axis.
Energy_train = np.linalg.norm(Activity_train + 1, axis=1) / Activity_train.shape[1]

#%% OCCUPANCY HISTOGRAMS
# Histogram grid for the unit activities: N_bins equal-width bins on
# [min_bins, max_bins].
N_bins = 50
min_bins = -1
max_bins = 1
delta_bins = (max_bins - min_bins)/N_bins


def _occupancy_counts(activity, n_agents, n_units, n_steps):
    """Count, per bin and time step, the activity values inside each bin.

    Only ``activity[:n_agents, :n_units, :n_steps]`` is considered, and the
    counts are summed over the first two axes. Bin membership is strict on
    both sides (``lower < value < upper``), so a value lying exactly on a bin
    edge is not counted in any bin.

    Returns a float array of shape ``(N_bins, n_steps)``.
    """
    # assumes `activity` is convertible with np.asarray (NumPy array or a
    # CPU tensor without grad) — TODO confirm against the saved data
    window = np.asarray(activity)[:n_agents, :n_units, :n_steps]
    counts = np.zeros((N_bins, n_steps))
    for bb in range(N_bins):
        lower = min_bins + bb * delta_bins
        upper = min_bins + (bb + 1) * delta_bins
        # One vectorised mask per bin replaces a Python loop over every
        # (agent, unit, time) element.
        inside = (lower < window) & (window < upper)
        counts[bb, :] = inside.sum(axis=(0, 1))
    return counts


Occupancy_free = _occupancy_counts(Activity_free, params['Naverage'], rnn.N, rnn.t)
Occupancy_MOP = _occupancy_counts(Activity_MOP, params['Naverage'], rnn.N, rnn.t)
Occupancy_greedy = _occupancy_counts(Activity_greedy, params['Naverage'], rnn.N, rnn.t)


#%% COLORMAP FOR THE ENERGIES

from matplotlib.colors import ListedColormap
from matplotlib.cm import magma, inferno, ocean, turbo, viridis, plasma
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from matplotlib.colors import Normalize


# Colour scale for the entropy values. The upper end is the maximum entropy,
# log(2**Nc); the lower end is the smallest entropy of the greedy agent, since
# MOP is always greater except for the single zero-entropy value reached when
# entering a terminal state.
entropy_low = np.min(H_greedy)
entropy_high = np.log(2**params['Nc'])
norm = Normalize(vmin=entropy_low, vmax=entropy_high)

# Index of the example trajectory shown in the time-series panels
traj = 1

# Layout of Figure 2: three rows by five columns; the narrow last column holds
# the colorbar in the middle row.
mosaic_layout = [
    ["Free", "Train", "MOP", "Greedy", 'no'],
    ["Energy_free", "Energy_train", "Energy_MOP", "Energy_greedy", 'colorbar'],
    ["pdf", "ED", "STD_single", "Tend", 'no2'],
]
grid_options = dict(
    hspace=0.5,
    wspace=0.75,
    width_ratios=[1, 1, 1, 1, 0.1],
)
fig, axes = plt.subplot_mosaic(mosaic_layout, gridspec_kw=grid_options)

# Free network: activity of the first 20 units of the example trajectory
# (top panel) and the corresponding energy trace (panel below).
free_steps = params['TotalT']/params['dt']
free_time = torch.arange(0, free_steps)

free_ax = axes['Free']
free_ax.plot(free_time, Activity_free[traj, :20, :].T, alpha=0.9)
free_ax.set_xlim(0, free_steps)
free_ax.set_xticks([0, 500, 1000], ['0', '500', '1000'], fontsize=12)
free_ax.set_yticks([-1, 0, 1], ['-1', '0', '1'], fontsize=12)
free_ax.set_xlabel('t', fontsize=12)
free_ax.set_ylabel('x(t)', fontsize=12)
free_ax.set_ylim(-1, 1)

free_energy_ax = axes['Energy_free']
free_energy_ax.plot(free_time, Energy_free[traj, 0], linewidth=2, color=pal_grey[3])
# Dashed horizontal reference at the constant level rnn.x_th_plus
free_energy_ax.plot(
    free_time,
    rnn.x_th_plus*torch.ones(rnn.t),
    linestyle='--',
    linewidth=1.5,
    color='#98aaac',
)
free_energy_ax.set_xlim(0, free_steps)
free_energy_ax.set_xticks([0, 500, 1000], ['0', '500', '1000'], fontsize=12)
free_energy_ax.set_xlabel('t', fontsize=12)
free_energy_ax.set_ylabel('E(t)', fontsize=12)
free_energy_ax.set_ylim([0.03, 0.15])
free_energy_ax.set_yticks([0.03, 0.11, 0.15], ['0.03', '0.11', '0.15'], fontsize=12)

for panel in (free_energy_ax, free_ax):
    panel.tick_params(axis='both', which='both', labelsize=12)

# Training trajectory: activity (top) and entropy-coloured energy (below),
# drawn only up to the first step at which unit 0 is exactly zero.
terminal_steps = np.where(Activity_train[traj, 0, :] == 0)[0]
reached_terminal = terminal_steps.size > 0
# If the trajectory never hits an exact zero, use its full length instead of
# failing with an IndexError on an empty index array.
t_end_train = terminal_steps[0] if reached_terminal else Activity_train.shape[2]

axes['Train'].plot(torch.arange(0, t_end_train), Activity_train[traj,:15, :t_end_train].T, alpha=0.9)
axes['Train'].set_xlim(0, params['TotalT']/params['dt'])
axes['Train'].set_yticks([-1,0,1], ['-1','0','1'], fontsize  = 12)
axes['Train'].set_xticks([0,500,1000], ['0','500','1000'], fontsize  = 12)
axes['Train'].set_xlabel('t', fontsize  = 12)
axes['Train'].set_ylabel('x(t)', fontsize  = 12)
axes['Train'].set_ylim(-1,1)

# Dashed horizontal reference at the constant level rnn.x_th_plus
axes['Energy_train'].plot(torch.arange(0, params['TotalT']/params['dt']), rnn.x_th_plus*torch.ones(rnn.t), linestyle='--', linewidth=1.5, color='#98aaac')
axes['Energy_train'].set_xlim(0, params['TotalT']/params['dt'])
axes['Energy_train'].set_xlabel('t', fontsize  = 12)
axes['Energy_train'].set_xticks([0,500,1000], ['0','500','1000'], fontsize  = 12)
axes['Energy_train'].set_ylabel('E(t)', fontsize  = 12)
axes['Energy_train'].set_ylim([0.03,0.15])
axes['Energy_train'].tick_params(axis='both', which='both', labelsize  = 12)
axes['Energy_train'].set_yticks([0.03, 0.11, 0.15], ['0.03', '0.11', '0.15'], fontsize = 12)
axes['Train'].tick_params(axis='both', which='both', labelsize  = 12)

# One RGBA colour per time step, taken from the entropy of the trajectory
colors = magma(norm(H_train[traj, 0, :t_end_train])).reshape(-1, 4)
# Inset zooming on the narrow energy band [0.10, 0.112]
inset_ax = axes['Energy_train'].inset_axes([0.62, 0.18,0.4,0.34])
inset_ax.plot(torch.arange(0, params['TotalT']/params['dt']), rnn.x_th_plus*torch.ones(rnn.t), linestyle='--', linewidth=1, color='#98aaac')
inset_ax.set_ylim(0.10, 0.112)  # Set smaller y-axis limit for the inset
inset_ax.set_xticks([0,1000], ['0','1000'], fontsize  = 12)
inset_ax.set_yticks([0.10, 0.110], ['0.10', '0.11'], fontsize = 12)
# Draw the energy as consecutive two-point segments, each coloured by the
# entropy at the segment's first step
for i in range(t_end_train - 1):
    axes['Energy_train'].plot([i, i + 1], Energy_train[traj, i:i+2], c=colors[i], linewidth = 1.5)
    inset_ax.plot([i, i + 1], Energy_train[traj, i:i+2], c=colors[i], linewidth = 0.5)
# Mark the last plotted step (the one before the terminal state) as a point,
# not a line, to show the zero entropy. Skipped when no terminal state was
# found, because then there is no such step to mark.
if reached_terminal and t_end_train > 0:
    inset_ax.scatter(t_end_train - 1,  Energy_train[traj, t_end_train-1:t_end_train], c=colors[-1], s=10, marker='.', zorder = 2)
    axes['Energy_train'].scatter(t_end_train - 1,  Energy_train[traj, t_end_train-1:t_end_train], c=colors[-1], s=10, marker='.', zorder = 2)


# MOP agent: activity of the first 15 units (top) and entropy-coloured energy
# with a zoomed inset (below).
mop_steps = params['TotalT']/params['dt']
mop_time = torch.arange(0, mop_steps)
mop_ax = axes['MOP']
mop_energy_ax = axes['Energy_MOP']

mop_ax.plot(mop_time, Activity_MOP[traj, :15, :].T, alpha=0.9)
mop_ax.set_xlim(0, mop_steps)
mop_ax.set_xticks([0, 500, 1000], ['0', '500', '1000'], fontsize=12)
mop_ax.set_yticks([-1, 0, 1], ['-1', '0', '1'], fontsize=12)
mop_ax.set_xlabel('t', fontsize=12)
mop_ax.set_ylabel('x(t)', fontsize=12)
mop_ax.set_ylim(-1, 1)
mop_ax.set_yticks([-1, 0, 1], ['-1', '0', '1'], fontsize=12)

# Dashed horizontal reference at the constant level rnn.x_th_plus
mop_energy_ax.plot(
    mop_time,
    rnn.x_th_plus*torch.ones(rnn.t),
    linestyle='--',
    linewidth=1.5,
    color='#98aaac',
)
mop_energy_ax.set_xlim(0, mop_steps)
mop_energy_ax.set_xticks([0, 500, 1000], ['0', '500', '1000'], fontsize=12)
mop_energy_ax.set_xlabel('t', fontsize=12)
mop_energy_ax.set_ylabel('E(t)', fontsize=12)
mop_energy_ax.set_ylim([0.03, 0.15])
mop_energy_ax.set_yticks([0.03, 0.11, 0.15], ['0.03', '0.11', '0.15'], fontsize=12)

# One RGBA colour per time step, taken from the entropy of the trajectory
colors = magma(norm(H_mop[traj, 0])).reshape(-1, 4)

inset_ax = mop_energy_ax.inset_axes([0.62, 0.18, 0.4, 0.34])
step_index = range(rnn.t)
for step in range(rnn.t - 1):
    # Two-point segment coloured by the entropy at its first step
    pair = slice(step, step + 2)
    mop_energy_ax.plot(step_index[pair], Energy_MOP[traj, 0, pair], c=colors[step], linewidth=1.5)
    inset_ax.plot(step_index[pair], Energy_MOP[traj, 0, pair], c=colors[step], linewidth=0.5)

inset_ax.plot(
    mop_time,
    rnn.x_th_plus*torch.ones(rnn.t),
    linestyle='--',
    linewidth=1,
    color='#98aaac',
)
inset_ax.set_ylim(0.10, 0.112)  # narrower y-range for the inset
inset_ax.set_xticks([0, rnn.t], ['0', '1000'], fontsize=12)
inset_ax.set_yticks([0.10, 0.110], ['0.10', '0.11'], fontsize=12)

for panel in (mop_ax, mop_energy_ax):
    panel.tick_params(axis='both', which='both', labelsize=12)

# Greedy agent: activity of the first 15 units (top) and entropy-coloured
# energy (below).
axes['Greedy'].plot(torch.arange(0, params['TotalT']/params['dt']), Activity_greedy[traj, :15, :].T, alpha=0.9)
axes['Greedy'].set_xticks([0,500,1000], ['0','500','1000'], fontsize  = 12)
axes['Greedy'].set_yticks([-1,0,1], ['-1','0','1'], fontsize  = 12)
axes['Greedy'].set_xlim(0, params['TotalT']/params['dt'])
axes['Greedy'].set_xlabel('t', fontsize  = 12)
axes['Greedy'].set_ylabel('x(t)', fontsize  = 12)
axes['Greedy'].set_ylim(-1,1)

# Dashed horizontal reference at the constant level rnn.x_th_plus
axes['Energy_greedy'].plot(torch.arange(0, params['TotalT']/params['dt']), rnn.x_th_plus*torch.ones(rnn.t), linestyle='--', linewidth=1.5, color='#98aaac')
axes['Energy_greedy'].set_xlim(0, params['TotalT']/params['dt'])
axes['Energy_greedy'].set_xticks([0,500,1000], ['0','500','1000'], fontsize  = 12)
axes['Energy_greedy'].set_yticks([0.03, 0.11, 0.15], ['0.03', '0.11', '0.15'], fontsize = 12)
axes['Energy_greedy'].set_xlabel('t', fontsize  = 12)
axes['Energy_greedy'].set_ylabel('E(t)', fontsize  = 12)
axes['Energy_greedy'].set_ylim([0.03,0.15])
# One RGBA colour per time step, taken from the entropy of the trajectory
colors = magma(norm(H_greedy[traj,0])).reshape(-1, 4)

# Draw the energy as two-point segments coloured by entropy. The x-values are
# taken from range(rnn.t), matching the loop bound, so each segment always
# has two x-values for its two y-values whatever the trajectory length.
for i in range(rnn.t - 1):
    axes['Energy_greedy'].plot(range(rnn.t)[i:i+2], Energy_greedy[traj,0,i:i+2], linewidth = 1, c=colors[i],linestyle='-')

axes['Energy_greedy'].tick_params(axis='both', which='both', labelsize  = 12)
axes['Greedy'].tick_params(axis='both', which='both', labelsize=12)

# t_end per epoch: mean over axis 0 with a shaded band of
# std / sqrt(Naverage), drawn for the greedy ('R') agent first, then MOP.
tend_ax = axes['Tend']
tend_epochs = np.arange(params['Nepochs'])
tend_root_n = np.sqrt(params['Naverage'])
for runs, tint, tag in ((T_end_greedy, pal_greedy, 'R'), (T_end, pal_MOP, 'MOP')):
    centre = np.mean(runs, 0)
    half_width = np.std(runs, 0)/tend_root_n
    tend_ax.plot(centre, alpha=0.9, color=tint, label=tag)
    tend_ax.fill_between(
        x=tend_epochs,
        y1=centre - half_width,
        y2=centre + half_width,
        alpha=0.2,
        color=tint,
    )
tend_ax.set_xlabel('epochs', fontsize=12)
tend_ax.set_ylabel('$t_{end}$', fontsize=12)
tend_ax.legend(frameon=False)
tend_ax.tick_params(axis='both', which='both', labelsize=12)

# Occupancy histograms as outlined horizontal bars: the time-averaged count
# per bin divided by rnn.N, for the free network, the greedy agent and MOP.
pdf_ax = axes['pdf']
bin_positions = np.arange(N_bins)
outlined_series = (
    (Occupancy_free, pal_grey[2], 'free'),
    (Occupancy_greedy, pal_greedy, r'$\epsilon$-greedy'),
    (Occupancy_MOP, pal_MOP, 'MOP'),
)
for occupancy, outline, tag in outlined_series:
    pdf_ax.barh(
        bin_positions,
        np.mean(occupancy, axis=1) / rnn.N,
        height=1,
        facecolor='none',
        edgecolor=outline,
        label=tag,
    )
# Bin indices 0, 24 and 49 are labelled with the activity values -1, 0 and 1
pdf_ax.set_yticks([0, 24, 49], ['-1', '0', '1'], fontsize=12)
pdf_ax.set_xticks([])
pdf_ax.tick_params(axis='both', which='both', labelsize=12)
pdf_ax.set_ylabel('x', fontsize=12)
pdf_ax.set_xlabel('pdf', fontsize=12)
pdf_ax.legend()
                
# Per-epoch <sigma>: mean over axis 0 with a shaded band of
# std / sqrt(Naverage), drawn for the greedy ('R') agent first, then MOP.
sigma_ax = axes['STD_single']
sigma_epochs = np.arange(params['Nepochs'])
sigma_root_n = np.sqrt(params['Naverage'])
for runs, tint, tag in ((STD_single_greedy, pal_greedy, 'R'), (STD_single, pal_MOP, 'MOP')):
    centre = np.mean(runs, 0)
    half_width = np.std(runs, 0)/sigma_root_n
    sigma_ax.plot(centre, alpha=0.9, color=tint, label=tag)
    sigma_ax.fill_between(
        x=sigma_epochs,
        y1=centre - half_width,
        y2=centre + half_width,
        alpha=0.2,
        color=tint,
    )
sigma_ax.set_xlabel('epochs', fontsize=12)
sigma_ax.set_ylabel(r'$\langle\sigma\rangle$', fontsize=12)
sigma_ax.legend(frameon=False)
sigma_ax.tick_params(axis='both', which='both', labelsize=12)

# Bar heights are the means of the ED arrays; error bars are the SEM
# (std / sqrt(Naverage)), averaged across agents. The greedy R agent never
# approaches the threshold, so its bar for that region would always be zero
# (that part of state space is not visited) and is left out.
x_pos = [0, 1, 2]
sem_denominator = np.sqrt(params['Naverage'])
data = [np.mean(ED_MOP_far), np.mean(ED_MOP_thresh), np.mean(ED_greedy_far)]
yerr = [np.std(ED_MOP_far)/sem_denominator, np.std(ED_MOP_thresh)/sem_denominator, np.std(ED_greedy_far)/sem_denominator]
# Exactly one colour per bar: two MOP bars followed by one greedy bar.
bar_colors = [pal_MOP, pal_MOP, pal_greedy]
axes['ED'].bar(x_pos, data, yerr = yerr, align='center', width=0.8, color = bar_colors)
# Put the y ticks at the bar heights themselves; the formatter set below
# prints them without decimals.
axes['ED'].set_yticks(data)
axes['ED'].set_xticks(x_pos, [r'$E(x)\ll L$', r'$E(x)\sim L$', r'$E(x)\ll L$'], fontsize = 9)
axes['ED'].set_ylabel('$ED_{a}$', fontsize  = 12)
axes['ED'].set_xlabel('state space regions', fontsize  = 12)
axes['ED'].tick_params(axis='y', which='both', labelsize  = 12)
axes['ED'].yaxis.set_major_formatter('{:.0f}'.format)

# Entropy colorbar: a magma ScalarMappable on the shared `norm`, drawn into
# the dedicated 'colorbar' axes of the mosaic.
colorbar_ax = axes['colorbar']
sm = plt.cm.ScalarMappable(cmap=magma, norm=norm)
sm.set_array([])  # the mappable carries no data of its own
cbar = fig.colorbar(sm, cax=colorbar_ax)
cbar.set_label(r'$\mathcal{H}(\mathcal{A}|x)$', fontsize=12)
colorbar_ax.set_yticks([2.27, 5.54], ['<2.27', '5.54'], fontsize=12)
cbar.outline.set_edgecolor('none')

# Resize the current figure (width, height in inches) and display it
fig = plt.gcf()
fig.set_size_inches(25, 15)
plt.show()
