import numpy as np
import dill
import matplotlib.pyplot as plt
import seaborn as sns

#%% VARIABLES TO SET
# NOTE(review): not referenced anywhere in this script; the number of contexts
# that is actually plotted comes from rnn.C / params['C'].
number_of_ctx = 6

#%% LOADING DATA

name_file = 'Figure4'
training_path = f'{name_file}_training.pkl'
activity_path = f'{name_file}_activity.pkl'

# NOTE: dill (like pickle) can execute arbitrary code while loading; only
# open result files you produced yourself.
with open(training_path, 'rb') as fh:
    data = dill.load(fh)
params, MSE, STD, STD_single, T_end, t_survival, rnn, fnn = data

with open(activity_path, 'rb') as fh:
    data = dill.load(fh)
Activity, H_actions = data

#%% PLOT FEATURES

# Grey palette (light -> dark), used as the default seaborn colour cycle.
pal_grey = ['f8f9fa', 'e9ecef', 'dee2e6', 'ced4da', 'adb5bd', '6c757d', '495057', '343a40']
pal_grey = [f"#{c}" for c in pal_grey]
color_palette = "#C55986"    # main accent colour (bars, curves)
color_palette_2 = "#76003A"  # darker accent colour

# Start from matplotlib's defaults. This wipes every rcParam set before it,
# which is why the text/font settings are applied further down.
sns.reset_defaults()
sns.set_theme(context="paper", style="ticks",
              # `my_pal` was never defined (NameError); use the grey palette above.
              palette=pal_grey,
              rc={
                  "pdf.fonttype": 42,  # embed TrueType (Type 42) fonts in PDF output
                  "svg.fonttype": "none",  # keep text as text in SVG (fonts not embedded)
                  "figure.facecolor": "white",
                  "figure.dpi": 100,
                  "axes.facecolor": "None",
                  "axes.spines.left": True,
                  "axes.spines.bottom": True,
                  "axes.spines.right": False,
                  "axes.spines.top": False,
              },
              )

# LaTeX text rendering with STIX fonts. These must come after
# reset_defaults()/set_theme(); otherwise they are reset and font.family is
# overwritten with sans-serif.
# NOTE: text.usetex requires a working LaTeX installation on PATH.
plt.rcParams['text.usetex'] = True
plt.rcParams['mathtext.fontset'] = 'stix'
plt.rcParams['font.family'] = 'STIXGeneral'

    
#%% SETTING THE CONTEXTS

# 2 x 6 mosaic:
# - six trajectory panels, one per context;
# - a blank spacer ("no") above the entropy colorbar;
# - the per-context survival bars ("Accuracy");
# - the two learning curves ("T_end", "STD").
fig, axes = plt.subplot_mosaic([
    ["Circle", "Plus", "Diamond", "no", "Accuracy", "Accuracy"],
    ["Heart", "Square", "Oval", "colorbar", "T_end", "STD"],
    ], gridspec_kw=dict(hspace=1,
                        wspace=0.8,
                        width_ratios=[1, 1, 1, 0.1, 0.5, 0.5]),)

all_contexts = ["Circle", "Plus", "Diamond", "Heart", 'Square', 'Oval']

# Index (first axis of Activity / H_actions) of the trajectory to plot.
# It is fixed, not drawn at random.
traj = 0

# Names of the contexts to plot: the first rnn.C entries of all_contexts.
context = all_contexts[:rnn.C]

from matplotlib.cm import magma
from matplotlib.colors import Normalize

# Colour scale for the action entropy H(A|x) along the plotted trajectory.
# The limits are the min/max of H_actions[traj, :, 0, :] over all contexts.
# That is the observed range, not the theoretical maximum entropy.
norm = Normalize(vmin=np.min(H_actions[traj, :, 0, :]),
                 vmax=np.max(H_actions[traj, :, 0, :]))

# Draw one short segment every `delta` time steps to keep the plot sparse.
delta = 3

for idx, name_ctx in enumerate(context):
    ax = axes[name_ctx]
    x1 = Activity[traj, idx, 0, :]
    x2 = Activity[traj, idx, 1, :]
    colors = magma(norm(H_actions[traj, idx, 0, :])).reshape(-1, 4)
    # Segment t joins steps t and t+1 and is coloured by the entropy at step t.
    # Start indices run until either the next point or the colours run out, so
    # the tail of the trajectory is drawn as well.
    stop = min(len(x1) - 1, len(colors))
    for t in range(0, stop, delta):
        ax.plot(x1[t:t + 2], x2[t:t + 2], c=colors[t], linewidth=1, zorder=2)
    ax.set_xlabel(r'$x_1$', fontsize=12)
    ax.set_ylabel(r'$x_2$', fontsize=12)
    ax.set_xticks([-1, 0, 1], ['-1', '0', '1'], fontsize=12)
    ax.set_yticks([-1, 0, 1], ['-1', '0', '1'], fontsize=12)
    ax.set_title(name_ctx, fontsize=12)

# Bar chart of the mean survival time per context, with the standard error
# of the mean pooled over all batches and repetitions.
n_samples = params['Nbatches'] * params['Naverage']
# Treat every (batch, repetition) pair as one sample: flatten to
# (n_samples, C). Presumably t_survival is laid out as
# (Nbatches, Naverage, C); verify against the training script.
Survival = t_survival.reshape((n_samples, params['C']))
mean_data = np.mean(Survival, axis=0)
std_data = np.std(Survival, axis=0) / np.sqrt(n_samples)  # standard error of the mean

# One bar per context, 1.5 units apart. Labels are sliced to the same length
# as the tick positions; passing all 6 names when rnn.C != 6 raises a
# tick/label count mismatch error.
x_pos = [1.5 * i for i in range(rnn.C)]
axes['Accuracy'].bar(x_pos, mean_data, yerr=std_data, width=0.5, align='center', color=color_palette)
axes['Accuracy'].set_xticks(x_pos)
axes['Accuracy'].set_xticklabels(all_contexts[:len(x_pos)], fontsize=12, rotation=60)
axes['Accuracy'].set_ylabel(r'${t}_{end}^{c}$', fontsize=12)
axes['Accuracy'].set_yticks([0, 400, 600], ['0', '400', '600'], fontsize=12)
axes['Accuracy'].set_title('Mean accuracy', fontsize=12)


# Learning curves over epochs: the mean across axis 0 with a +/- standard
# error band. Presumably axis 0 holds the params['Naverage'] repetitions.
epochs = np.arange(params['Nepochs'])
sem_scale = np.sqrt(params['Naverage'])

t_end_mean = np.mean(T_end, 0)
t_end_sem = np.std(T_end, 0) / sem_scale
ax_t_end = axes['T_end']
ax_t_end.plot(t_end_mean, alpha=0.9, color=color_palette)
ax_t_end.fill_between(x=epochs, y1=t_end_mean - t_end_sem, y2=t_end_mean + t_end_sem,
                      alpha=0.2, color=color_palette)
ax_t_end.set_xlabel('epochs', fontsize=12)
ax_t_end.set_yticks([0, 400, 600], ['0', '400', '600'], fontsize=12)
ax_t_end.set_xticks([0, 100], ['0', '100'], fontsize=12)
ax_t_end.set_ylabel('$t_{end}$', fontsize=12)

sigma_mean = np.mean(STD_single, 0)
sigma_sem = np.std(STD_single, 0) / sem_scale
ax_sigma = axes['STD']
ax_sigma.plot(sigma_mean, alpha=0.9, color=color_palette, label=r'$\langle\sigma\rangle$')
# The error band of this panel uses the darker accent colour at low alpha.
ax_sigma.fill_between(x=epochs, y1=sigma_mean - sigma_sem, y2=sigma_mean + sigma_sem,
                      alpha=0.1, color=color_palette_2)
ax_sigma.set_xlabel('epochs', fontsize=12)
ax_sigma.set_xticks([0, 100], ['0', '100'], fontsize=12)
ax_sigma.set_yticks([0, 0.3], ['0', '0.3'], fontsize=12)
ax_sigma.set_ylabel(r'$\langle\sigma\rangle$', fontsize=12)


# Colorbar for the entropy colouring of the trajectory panels. A bare
# ScalarMappable carries the same colormap and normalisation as the lines.
sm = plt.cm.ScalarMappable(cmap=magma, norm=norm)
sm.set_array([])  # kept for compatibility with older matplotlib versions
cbar = fig.colorbar(sm, cax=axes['colorbar'])
cbar.set_label(r'$\mathcal{H}(\mathcal{A}|x)$', fontsize=12)
# Put ticks at the two ends of the colour scale, taken from `norm`, not
# hard-coded. Axis.set_ticks expands the limits to include ticks outside
# [vmin, vmax], which would leave unpainted space on the colorbar.
cbar_ticks = [norm.vmin, norm.vmax]
cbar.set_ticks(cbar_ticks)
cbar.set_ticklabels([f'{v:.3g}' for v in cbar_ticks])
cbar.ax.tick_params(labelsize=12)
cbar.outline.set_edgecolor('none')

# Hide the spacer panel "no" entirely: no ticks, no tick labels, no spines.
spacer = axes['no']
spacer.tick_params(axis='both', which='both',
                   bottom=False, left=False,
                   labelbottom=False, labelleft=False)
for side in ('top', 'right', 'bottom', 'left'):
    spacer.spines[side].set_visible(False)

# Final figure size (inches) and margins, then display.
fig = plt.gcf()
fig.set_size_inches(10, 4)
plt.subplots_adjust(left=0.1, right=0.9, top=0.9, bottom=0.2)

plt.show()

