#%% SETTING THE PARAMETERS

import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import dill

#%% PLOT FEATURES
# Grey palette used as the default seaborn color cycle (hex strings without '#').
pal_grey = ['dee2e6', 'ced4da', 'adb5bd', '6c757d', '495057', '343a40']
pal_grey = [f"#{c}" for c in pal_grey]
pal_MOP = '#c55986ff'
pal_greedy = '#354a78ff'

# set default style
sns.reset_defaults() # useful when adjusting style a lot
sns.set_theme(context="paper", style="ticks",
              palette=pal_grey,
              rc={
                  "pdf.fonttype": 42,  # embed font in output
                  "svg.fonttype": "none",  # embed font in output
                  "figure.facecolor": "white",
                  "figure.dpi": 100,
                  "axes.facecolor": "None",
                  "axes.spines.left": True,
                  "axes.spines.bottom": True,
                  "axes.spines.right": False,
                  "axes.spines.top": False,
              },
              )

# Font settings must come AFTER reset_defaults()/set_theme(): reset_defaults()
# restores matplotlib's default rcParams and set_theme() rewrites font.family,
# so setting these earlier silently has no effect.
# NOTE: text.usetex=True requires a working LaTeX installation.
plt.rcParams['text.usetex'] = True
plt.rcParams['mathtext.fontset'] = 'stix'
plt.rcParams['font.family'] = 'STIXGeneral'

#%%LOAD DATA
# Both result files share this stem. They are deserialized with dill
# (pickle-compatible), which can execute arbitrary code on load: only open
# files that come from a trusted source.
name_file = 'Figure7'

with open(f"{name_file}_training.pkl", "rb") as file:
    data = dill.load(file)

# The unpacking order must match the order the objects were dumped in
# (presumably by the script that produced the .pkl files -- confirm there).
(params, MSE, STD, STD_single,
 T_end, rnn, MOP) = data

with open(f"{name_file}_activity.pkl", "rb") as file:
    data_activity = dill.load(file)

(ED_MOP_far, ED_MOP_thresh,
 Activity_MOP, Energy_MOP, H_mop,
 Activity_free, Energy_free,
 Activity_train, H_train) = data_activity
    
#%% OCCUPANCY HISTOGRAMS
# For every time step, count how many activity values (pooled over the first
# params['Naverage'] trajectories and the first rnn.N units) fall into each of
# N_bins equal-width bins on [min_bins, max_bins).
N_bins = 50
min_bins = 0
max_bins = 2 #Here we choose an arbitrary maximum value, the Free network overshoots quickly this value
delta_bins = (max_bins - min_bins)/N_bins

# Bin b covers [bin_edges[b], bin_edges[b + 1]). Edges are computed as
# min_bins + b * delta_bins, i.e. the same bounds the per-bin comparison used,
# but the lower edge is now inclusive so values sitting exactly on an edge
# (e.g. exact zeros) are no longer dropped from every bin.
bin_edges = min_bins + np.arange(N_bins + 1) * delta_bins


def _occupancy(activity):
    """Return an (N_bins, rnn.t) array of per-time-step bin counts.

    ``activity`` is indexed as [trajectory, unit, time], matching the
    ``Activity_*[ag, nn, tt]`` indexing used elsewhere in this script. Values
    outside [min_bins, max_bins) and NaNs are not counted.
    """
    if isinstance(activity, torch.Tensor):
        activity = activity.detach().cpu().numpy()
    # Same index ranges the original nested loops visited.
    values = np.asarray(activity)[:params['Naverage'], :rnn.N, :rnn.t]
    # searchsorted(side='right') - 1 yields b with edges[b] <= x < edges[b+1];
    # x < min_bins -> -1, x >= max_bins or NaN -> N_bins (both discarded below).
    bin_idx = np.searchsorted(bin_edges, values, side='right') - 1
    time_idx = np.broadcast_to(np.arange(values.shape[-1]), values.shape)
    counted = (bin_idx >= 0) & (bin_idx < N_bins)
    occupancy = np.zeros((N_bins, rnn.t))
    # Unbuffered accumulation so repeated (bin, time) pairs are all counted.
    np.add.at(occupancy, (bin_idx[counted], time_idx[counted]), 1)
    return occupancy


Occupancy_free = _occupancy(Activity_free)
Occupancy_MOP = _occupancy(Activity_MOP)

#%% COLORMAP FOR THE ENERGIES

from matplotlib.colors import ListedColormap
from matplotlib.cm import magma, inferno, ocean, turbo, viridis, plasma
from matplotlib.colors import Normalize
from matplotlib.collections import LineCollection

# Color scale for H_mop: the lower end is the smallest value in H_mop, the upper
# end is the maximum entropy log(2**Nc), written as Nc*log(2) so a large Nc
# cannot overflow the int -> float conversion.
norm = Normalize(vmin=np.min(H_mop), vmax=params['Nc'] * np.log(2))

# Trajectory (first index of the activity/energy arrays) shown in the
# time-series panels.
traj = 0

fig, axes = plt.subplot_mosaic(
    [["Free", "MOP", "Energy_MOP", 'colorbar', "ED", "Tend"]],
    gridspec_kw=dict(hspace=0.5,
                     wspace=1,
                     width_ratios=[1, 1, 1, 0.1, 0.8, 1]),
)

# Shared time axis for the three time-series panels (one point per step).
n_steps = params['TotalT'] / params['dt']
time_steps = torch.arange(0, n_steps)
time_ticks = [0, n_steps]
time_ticklabels = ['0', f'{n_steps:g}']

# --- Activity traces: Free (first 20 units) and MOP (first 15 units) ---
axes['Free'].plot(time_steps, Activity_free[traj, :20, :].T, alpha=0.9)
axes['MOP'].plot(time_steps, Activity_MOP[traj, :15, :].T, alpha=0.9)
for key in ('Free', 'MOP'):
    ax = axes[key]
    ax.set_xlim(0, n_steps)
    ax.set_xticks(time_ticks, time_ticklabels, fontsize=12)
    ax.set_yticks([0, 2], ['0', '2'], fontsize=12)
    ax.set_xlabel('t', fontsize=12)
    ax.set_ylabel('$x(t)$', fontsize=12)
    ax.set_ylim(0, 2)
    ax.tick_params(axis='both', which='both', labelsize=12)

# --- Energy of the MOP network, colored by H_mop at each step ---
ax = axes['Energy_MOP']
# Dashed horizontal reference line at rnn.x_th_plus.
ax.plot(time_steps, rnn.x_th_plus * torch.ones(rnn.t),
        linestyle='--', linewidth=1.5, color='#98aaac')
colors = magma(norm(H_mop[traj, 0])).reshape(-1, 4)
energy = Energy_MOP[traj, 0, :]
if isinstance(energy, torch.Tensor):
    energy = energy.detach().cpu().numpy()
energy = np.asarray(energy, dtype=float)
# Segment i joins steps i and i+1 and takes colors[i]; one LineCollection
# replaces one plot() call per segment and uses the true step index as x
# (instead of a hard-coded range(1000)).
points = np.column_stack([np.arange(len(energy)), energy])
segments = np.stack([points[:-1], points[1:]], axis=1)
# zorder=2 keeps the trace above the dashed line, as separate Line2D objects were.
ax.add_collection(LineCollection(segments, colors=colors[:len(segments)],
                                 linewidths=0.6, zorder=2))
ax.set_xlim(0, n_steps)
ax.set_xticks(time_ticks, time_ticklabels, fontsize=12)
ax.set_xlabel('t', fontsize=12)
ax.set_ylabel('E(t)', fontsize=12)
ax.set_ylim([0.10, 0.15])
ax.set_yticks([0.10, 0.13, 0.15], ['0.10', '0.13', '0.15'], fontsize=12)
ax.tick_params(axis='both', which='both', labelsize=12)

# --- t_end across epochs: mean +/- SEM over the Naverage runs ---
T_end_mean = np.mean(T_end, 0)
T_end_sem = np.std(T_end, 0) / np.sqrt(params['Naverage'])
axes['Tend'].plot(T_end_mean, alpha=0.9, color=pal_MOP)
axes['Tend'].fill_between(x=np.arange(params['Nepochs']),
                          y1=T_end_mean - T_end_sem,
                          y2=T_end_mean + T_end_sem,
                          alpha=0.2, color=pal_MOP)
axes['Tend'].set_xlabel('epochs', fontsize=12)
axes['Tend'].set_ylabel('$t_{end}$', fontsize=12)
axes['Tend'].set_xticks([0, 40], ['0', '40'])
axes['Tend'].set_yticks([0, 1000], ['0', '1000'])
axes['Tend'].tick_params(axis='both', which='both', labelsize=12)

# --- ED_a in the two state-space regions: mean with SEM error bars ---
x_pos = [0, 1]
data = [np.mean(ED_MOP_far), np.mean(ED_MOP_thresh)]
yerr = [np.std(ED_MOP_far) / np.sqrt(params['Naverage']),
        np.std(ED_MOP_thresh) / np.sqrt(params['Naverage'])]
axes['ED'].bar(x_pos, data, yerr=yerr, align='center', width=0.8,
               color=[pal_MOP, pal_MOP])
# Label the single y tick with its actual value. It used to be hard-coded as
# '6.87' and then silently overwritten by an integer formatter set afterwards.
axes['ED'].set_yticks([data[0]], [f'{data[0]:.2f}'])
axes['ED'].set_xticks(x_pos, [r'$E(x)\ll L$', r'$E(x)\sim L$'], fontsize=9)
axes['ED'].set_ylabel('$ED_{a}$', fontsize=12)
axes['ED'].set_xlabel('state space regions', fontsize=12)
axes['ED'].tick_params(axis='y', which='both', labelsize=12)

# --- Colorbar for the H_mop coloring of the energy trace ---
sm = plt.cm.ScalarMappable(cmap=magma, norm=norm)
sm.set_array([])  # empty array for the scalar mappable
cbar = fig.colorbar(sm, cax=axes['colorbar'])
cbar.set_label(r'$\mathcal{H}(\mathcal{A}|x)$', fontsize=12)
cbar.outline.set_edgecolor('none')
# Ticks at the ends of the color scale, derived from norm rather than hard-coded.
cbar_ticks = [norm.vmin, norm.vmax]
axes['colorbar'].set_yticks(cbar_ticks, [f'{v:.2f}' for v in cbar_ticks],
                            fontsize=12)

fig.subplots_adjust(left=0.1, right=0.9, top=0.9, bottom=0.3)  # Adjust margins
fig.set_size_inches(15, 2.5)  # Set the size in inches
plt.show()
