# Importing the libraries 

import numpy as np
import dill
import matplotlib.pyplot as plt
import seaborn as sns

#%% VARIABLES TO SET

# Base name of the result files of the main run; the suffixes
# '_training_pre.pkl' and '_activity.pkl' are appended when loading.
name_file = 'Figure3'

# Base name of the additional runs without exponential decay. One file per
# entry of `eps` is expected, with str(epsilon) embedded in the file name.
name_file_eps = 'Figure3eps'
eps = [1e-1, 1e-2, 1e-3, 1e-4, 1e-5]

#%%LOADING DATA

# NOTE: dill, like pickle, can run arbitrary code while loading. Only open
# result files that come from a trusted source.

# Training results of the main run: a 12-element tuple.
with open(name_file + '_training_pre.pkl', 'rb') as file:
    data = dill.load(file)

(params,
 MSE_pre, STD_pre, STD_single_pre, T_end_pre,
 MSE_greedy_pre, STD_greedy_pre, STD_single_greedy_pre, T_end_greedy_pre,
 rnn, MOP, greedy) = data

# Recorded activity of the main run: an 8-element tuple.
with open(name_file + '_activity.pkl', 'rb') as file:
    data_activity = dill.load(file)

(Activity, Energy, H_mop,
 Activity_greedy, Energy_greedy, H_greedy,
 Activity_free, Energy_free) = data_activity

# One result file per epsilon. T_end_epsilon is filled in the same order as
# `eps`, i.e. T_end_epsilon[k] belongs to eps[k].
T_end_epsilon = []

for epsilon in eps:
    eps_path = name_file_eps + '_' + str(epsilon) + '_training_pre.pkl'
    with open(eps_path, 'rb') as file:
        data = dill.load(file)

    # The file holds a 7-element tuple; only the fifth entry is used.
    _, _, _, _, surv, _, _ = data
    T_end_epsilon.append(surv)

#%% PLOT FEATURES

# Colours used by the figure.
pal_MOP = '#C55986'
pal_MOP_2 = '#76003A'
pal_greedy = '#354a78ff'
pal_epsilon = ['#586B9C', '#7C8DC0', '#A0B1E7', '#C6D7FF']
pal_grey = ['f8f9fa', 'e9ecef', 'dee2e6', 'ced4da', 'adb5bd', '6c757d', '495057', '343a40']
pal_grey = [f"#{c}" for c in pal_grey]

# Start from matplotlib's defaults (useful when adjusting the style a lot).
# This must run BEFORE any rcParams customisation: reset_defaults() restores
# every rcParam, so settings made earlier would be silently discarded.
sns.reset_defaults()

# set default style
sns.set_theme(context="paper", style="ticks",
              palette=pal_grey,
              rc={
                  "pdf.fonttype": 42,  # embed font in output
                  "svg.fonttype": "none",  # embed font in output
                  "figure.facecolor": "white",
                  "figure.dpi": 100,
                  "axes.facecolor": "None",
                  "axes.spines.left": True,
                  "axes.spines.bottom": True,
                  "axes.spines.right": False,
                  "axes.spines.top": False,
              },
              )

# Font settings are applied last so that neither reset_defaults() nor the
# font.family chosen by set_theme() overrides them.
# NOTE: text.usetex needs a working LaTeX installation.
plt.rcParams['text.usetex'] = True
plt.rcParams['mathtext.fontset'] = 'stix'
plt.rcParams['font.family'] = 'STIXGeneral'

# Only returns the palette object (no state is changed); presumably kept to
# preview the palette when the cell is run interactively.
sns.color_palette(pal_grey)

#%% OCCUPANCY
# Histogram of the unit activities over N_bins equal-width bins spanning
# [min_bins, max_bins], computed separately for every time step:
# Occupancy_*[bb, tt] is the number of units (among the first rnn.N rows)
# whose activity at time step tt lies in bin bb.
N_bins = 50
min_bins = -1
max_bins = 1
delta_bins = (max_bins - min_bins)/N_bins

# Bin edges min_bins + bb * delta_bins for bb = 0 .. N_bins.
bin_edges = min_bins + delta_bins * np.arange(N_bins + 1)


def _occupancy(activity):
    """Return the (N_bins, rnn.t) array of per-time-step bin counts.

    Bins are half-open ``[left, right)``; the last bin also contains
    ``max_bins`` (``np.histogram`` convention). Values lying exactly on a bin
    edge, such as 0.0 or +-1.0, are therefore counted once instead of being
    dropped, which is what strict ``<`` comparisons on both sides would do.
    Values outside ``[min_bins, max_bins]`` are ignored.
    """
    # np.asarray assumes the traces are NumPy arrays or array-convertible
    # (e.g. CPU tensors without gradients) -- TODO confirm for the pickled data.
    traces = np.asarray(activity)[:rnn.N, :rnn.t]
    counts = np.zeros((N_bins, rnn.t))
    for tt in range(rnn.t):
        counts[:, tt] = np.histogram(traces[:, tt], bins=bin_edges)[0]
    return counts


Occupancy_free = _occupancy(Activity_free)
Occupancy_MOP = _occupancy(Activity)
Occupancy_greedy = _occupancy(Activity_greedy)


#%%

                 

# Two-row panel layout. Each name becomes a key of `axes`; a name repeated
# in neighbouring cells makes that panel span several grid columns.
mosaic_layout = [
    ["no", "A_MOP", "A_MOP", "A_MOP", "ctx_MOP", "ctx_MOP_zoom", "colorbar"],
    ["T_end", "T_end", "T_end_decay", "T_end_decay", "ctx_greedy", "ctx_greedy_zoom", "colorbar2"],
]
grid_options = dict(
    hspace=0.8,
    wspace=1.2,
    width_ratios=[0.5, 0.1, 0.1, 0.5, 1, 1, 0.1],
)
fig, axes = plt.subplot_mosaic(mosaic_layout, gridspec_kw=grid_options)


from matplotlib.cm import magma
from matplotlib.colors import Normalize

# Colour scale shared by the entropy-coloured trajectories and the colourbars.
# The lower bound is the smallest value in H_mop (H_greedy is not taken into
# account); the upper bound is the maximum entropy log(2**Nc).
entropy_min = np.min([np.min(H_mop.numpy())])
entropy_max = np.log(2**params['Nc'])
norm = Normalize(vmin=entropy_min, vmax=entropy_max)

# MOP run in the (x_1, x_2) plane: one RGBA colour per entry of H_mop,
# taken from the magma colormap through `norm`.
colors = magma(norm(H_mop)).reshape(-1, 4)

# Grey background trajectories of three further unit pairs. They are drawn
# first so that the coloured trajectory ends up on top. Each pair gets its
# own label, matching the labels used in the greedy panel.
for (row_a, row_b), pair_label in zip(((2, 3), (4, 5), (6, 7)),
                                      (r'$x_3, x_4$', r'$x_5, x_6$', r'$x_7, x_8$')):
    axes["ctx_MOP"].plot(Activity[row_a,:], Activity[row_b,:], color = pal_grey[2], label = pair_label, alpha = 0.6)
axes["ctx_MOP"].set_xlabel(r'$x_1$', fontsize = 12)
axes["ctx_MOP"].set_ylabel(r'$x_2$', fontsize = 12)
axes["ctx_MOP"].set_xticks([-1,0,1], ['-1', '0', '1'], fontsize = 12)
axes["ctx_MOP"].set_yticks([-1,0,1], ['-1', '0', '1'], fontsize = 12)

# Draw the (x_1, x_2) trajectory segment by segment so that every segment
# can take the colour of its starting time step.
# NOTE(review): the zoom panel shows the grey pairs (2,3) and (4,5) only,
# while the greedy zoom panel also shows (6,7) -- confirm this is intended.
for i in range(len(Activity[0,:]) - 1):
    axes["ctx_MOP"].plot(Activity[0,i:i+2], Activity[1,i:i+2], c=colors[i], linewidth = 1)
    axes["ctx_MOP_zoom"].plot(Activity[0,i:i+2], Activity[1,i:i+2], c=colors[i], linewidth = 1)
    axes["ctx_MOP_zoom"].plot(Activity[2,i:i+2], Activity[3,i:i+2], c=pal_grey[4], linewidth = 0.6, alpha = 0.5)
    axes["ctx_MOP_zoom"].plot(Activity[4,i:i+2], Activity[5,i:i+2], c=pal_grey[4], linewidth = 0.6, alpha = 0.5)
axes["ctx_MOP_zoom"].set_xlabel(r'$x_1$', fontsize = 12)
axes["ctx_MOP_zoom"].set_ylabel(r'$x_2$', fontsize = 12)
axes["ctx_MOP_zoom"].set_xticks([-0.2,0,0.2], ['-0.2', '0', '0.2'], fontsize = 12)
axes["ctx_MOP_zoom"].set_yticks([-0.2,0,0.2], ['-0.2', '0', '0.2'], fontsize = 12)
# Zoom window around the origin.
axes["ctx_MOP_zoom"].set_xlim([-0.2,0.2])
axes["ctx_MOP_zoom"].set_ylim([-0.2,0.2])

# Activity traces of the MOP run over time. Rows 3..19 use the default colour
# cycle; row 2 is drawn separately and carries the legend entry for the units
# other than x_1 and x_2, which are highlighted in the MOP colour.
time_axis = np.arange(0, rnn.t)
axes["A_MOP"].plot(time_axis, Activity[3:20,:].T, linestyle='-', markersize = 2)
axes["A_MOP"].plot(time_axis, Activity[2,:].T, linestyle='-', color = pal_grey[6], markersize = 2, label = r"$x_i$ $\forall i \ne 1,2$ ")
axes["A_MOP"].plot(time_axis, Activity[0,:].T,  linestyle='-', color = pal_MOP, markersize = 2, label = r"$x_1, x_2$")
axes["A_MOP"].plot(time_axis, Activity[1,:].T,  linestyle='-', color = pal_MOP, markersize = 2)
axes["A_MOP"].set_ylabel('x(t)', fontsize = 12)
axes["A_MOP"].set_xlabel('t', fontsize = 12)
# The label of the last tick is derived from rnn.t so that it cannot drift
# away from the tick position when the simulation length changes.
axes["A_MOP"].set_xticks([0, rnn.t], ['0', str(rnn.t)], fontsize = 12)
axes["A_MOP"].set_yticks([-1,0,1], ['-1', '0', '1'], fontsize = 12)
axes["A_MOP"].legend(loc=(0.05, 0.9), frameon=False)

# Greedy run in the (x_1, x_2) plane: one RGBA colour per entry of H_greedy,
# taken from the magma colormap through `norm`.
colors = magma(norm(H_greedy)).reshape(-1, 4)

ax_full = axes["ctx_greedy"]
ax_zoom = axes["ctx_greedy_zoom"]
grey_pairs = ((2, 3), (4, 5), (6, 7))
grey_labels = (r'$x_3, x_4$', r'$x_5, x_6$', r'$x_7, x_8$')

# Grey background trajectories first, so the coloured one is drawn on top.
for (row_a, row_b), pair_label in zip(grey_pairs, grey_labels):
    ax_full.plot(Activity_greedy[row_a,:], Activity_greedy[row_b,:],
                 color = pal_grey[3], label = pair_label, alpha = 0.6)
ax_full.set_xlabel(r'$x_1$', fontsize = 12)
ax_full.set_ylabel(r'$x_2$', fontsize = 12)
ax_full.set_xticks([-1,0,1], ['-1', '0', '1'], fontsize = 12)
ax_full.set_yticks([-1,0,1], ['-1', '0', '1'], fontsize = 12)

# One plot call per segment so that each segment takes the colour of its
# starting time step; the zoom panel additionally gets the grey pairs.
n_segments = len(Activity_greedy[0,:]) - 1
for seg in range(n_segments):
    window = slice(seg, seg + 2)
    ax_full.plot(Activity_greedy[0,window], Activity_greedy[1,window], c=colors[seg], linewidth = 1)
    ax_zoom.plot(Activity_greedy[0,window], Activity_greedy[1,window], c=colors[seg], linewidth = 1)
    for row_a, row_b in grey_pairs:
        ax_zoom.plot(Activity_greedy[row_a,window], Activity_greedy[row_b,window],
                     c=pal_grey[4], linewidth = 0.6, alpha = 0.5)
ax_zoom.set_xlabel(r'$x_1$', fontsize = 12)
ax_zoom.set_ylabel(r'$x_2$', fontsize = 12)
ax_zoom.set_xticks([-0.2,0,0.2], ['-0.2', '0', '0.2'], fontsize = 12)
ax_zoom.set_yticks([-0.2,0,0.2], ['-0.2', '0', '0.2'], fontsize = 12)
# Zoom window around the origin.
ax_zoom.set_xlim([-0.2,0.2])
ax_zoom.set_ylim([-0.2,0.2])


# t_end per epoch for both agents: mean over axis 0 with a shaded band of
# +- std / sqrt(Naverage) around it.
epochs = np.arange(params['Nepochs'])
sqrt_n_average = np.sqrt(params['Naverage'])
for t_end_runs, curve_color, curve_label in (
        (T_end_pre, pal_MOP, 'MOP'),
        (T_end_greedy_pre, pal_greedy, r'R ($\epsilon$- decay)'),
):
    mean_curve = np.mean(t_end_runs, 0)
    sem_curve = np.std(t_end_runs, 0) / sqrt_n_average
    axes['T_end_decay'].plot(mean_curve, alpha=0.9, color = curve_color, label = curve_label)
    axes['T_end_decay'].fill_between(x = epochs, y1 = mean_curve - sem_curve, y2 = mean_curve + sem_curve, alpha = 0.2, color = curve_color)
axes['T_end_decay'].set_xlabel('epochs', fontsize = 12)
axes['T_end_decay'].set_ylabel('$t_{end}$', fontsize = 12)
# The labels of the last ticks are derived from rnn.t and params['Nepochs']
# so that they always agree with the tick positions.
axes["T_end_decay"].set_yticks([0, 600, rnn.t], ['0', '600', str(rnn.t)], fontsize = 12)
axes["T_end_decay"].set_xticks([0, params['Nepochs']], ['0', str(params['Nepochs'])], fontsize = 12)
axes['T_end_decay'].legend(loc='upper left', frameon=False)


# Final-epoch t_end (mean +- std / sqrt(Naverage)) for every fixed epsilon.
# T_end_epsilon is stored in the same order as `eps`, so both are iterated
# together; pairing the data with a reversed epsilon list would attach the
# wrong epsilon to every point's label.
sqrt_n_average = np.sqrt(params['Naverage'])
for position, (epsilon, surv) in enumerate(zip(eps, T_end_epsilon), start=1):
    axes['T_end'].errorbar(x = position, y = np.mean(surv, 0)[-1], yerr = np.std(surv, 0)[-1]/sqrt_n_average, fmt = '.', alpha=0.9, color = pal_greedy, label = r'$\epsilon = $' + str(epsilon))
axes['T_end'].set_xlabel(r'$\epsilon$', fontsize = 12)
axes['T_end'].set_ylabel('$t_{end}$', fontsize = 12)
# The label of the last tick is derived from rnn.t so that it always agrees
# with the tick position.
axes["T_end"].set_yticks([0, 600, rnn.t], ['0', '600', str(rnn.t)], fontsize = 12)
axes["T_end"].set_xticks([i+1 for i in range(len(eps))])
# NOTE: these labels are written out by hand and must list the entries of
# `eps` in order (one label per epsilon).
axes["T_end"].set_xticklabels([r'$10^{-1}$',r'$10^{-2}$',r'$10^{-3}$',r'$10^{-4}$ ',r'$10^{-5}$'], fontsize = 11)

# Both colourbar panels show the same magma scale defined by `norm`.
# NOTE(review): the upper tick 5.54 is hard-coded; presumably it stands for
# log(2**Nc) with Nc = 8 -- confirm if Nc changes.
for cbar_key in ('colorbar', 'colorbar2'):
    sm = plt.cm.ScalarMappable(cmap=magma, norm=norm)
    sm.set_array([])  # the mappable carries no data, only the colour scale
    cbar = fig.colorbar(sm, cax=axes[cbar_key])
    cbar.set_label(r'$\mathcal{H}(\mathcal{A}|x)$',  fontsize = 12)
    axes[cbar_key].set_yticks([0.00, 5.54], ['0.00', '5.54'], fontsize = 12)
    cbar.outline.set_edgecolor('none')

# The 'no' panel is only a spacer: hide its ticks, tick labels and all spines.
blank_ax = axes['no']
blank_ax.tick_params(axis='both', which='both', bottom=False, left=False, labelbottom=False, labelleft=False)
for side in ('top', 'right', 'bottom', 'left'):
    blank_ax.spines[side].set_visible(False)

# Final figure size (in inches) and outer margins, then display.
fig = plt.gcf()
fig.set_size_inches(12, 6)
plt.subplots_adjust(left=0.1, right=0.9, top=0.9, bottom=0.2)
plt.show()

