# -*- coding: utf-8 -*-
"""
Multitimescale
"""

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

# Use Arial for all sans-serif text. Passing serif=['Arial'] would only change
# the *serif* font list, which has no effect while the family is 'sans-serif'.
# Arial is prepended so matplotlib's default sans-serif fonts remain available
# as fallbacks on systems where Arial is not installed.
plt.rc('font', family='sans-serif')
plt.rcParams['font.sans-serif'] = ['Arial'] + list(plt.rcParams['font.sans-serif'])
plt.rc('axes', edgecolor='black')
plt.rc('xtick', color='black')
plt.rc('ytick', color='black')
plt.rc('grid', color='black')

# -----------------------------
# 1) Target-switch task: three learners track T1 (panel B)
#    - slow weight only       -> y0_log
#    - fast + slow weights    -> y_log
#    - gain-modulated weight  -> y_mod_log
# -----------------------------
n_steps1 = 300
x = 1.0  # constant scalar input shared by all learners

# Target: +1 everywhere except -1 on steps 144..153 (step 154 is back at +1).
# NOTE(review): the shaded window in panel B and `win` further down span
# 144..155, i.e. one step past the end of the -1 block. Confirm which end
# point is intended.
T1 = np.ones(n_steps1)
T1[144:154] = -1

# Fast/slow learners: learning rates and per-step decay of the fast weight
alpha_fast, alpha_slow, fast_decay = 0.3, 0.2, 0.8
# Gain-modulated learner: weight rate, baseline gain, gain decay, gain rate
alpha0, g0, g_decay, g_alpha = 0.2, 1.0, 0.8, 0.4

w0 = w_fast = w_slow = 0.0  # slow-only and fast+slow learners start at zero
g = w = g0                  # gain-modulated learner starts at w = g = g0

# Output of each learner at every step (recorded before that step's update)
y0_log, y_log, y_mod_log = (np.zeros(n_steps1) for _ in range(3))
# Parameter snapshots taken at the start of each step; g_log instead holds
# the gain *after* the step's update.
(w0_log, wf_log, ws_log, wt_log,
 wg_log, gg_log, g_log) = (np.zeros(n_steps1) for _ in range(7))

for t, target in enumerate(T1):
    # snapshot of every learner's parameters before updating
    w0_log[t], wf_log[t], ws_log[t] = w0, w_fast, w_slow
    wt_log[t] = w_fast + w_slow
    wg_log[t], gg_log[t] = w, g

    # (a) slow-only learner: plain delta rule
    y0 = w0 * x
    y0_log[t] = y0
    err0 = target - y0
    w0 = w0 + alpha_slow * err0 * x

    # (b) fast + slow learner: both weights share one prediction error;
    #     the fast weight also decays toward zero every step
    w_total = w_fast + w_slow
    y_log[t] = w_total * x
    err1 = target - w_total * x
    w_fast, w_slow = (fast_decay * w_fast + alpha_fast * err1 * x,
                      w_slow + alpha_slow * err1 * x)

    # (c) gain-modulated learner: the gain relaxes toward g0 and grows with
    #     |w * err|; it is updated first and the weight step uses the new gain
    ygm = g * w * x
    y_mod_log[t] = ygm
    err2 = target - ygm
    g = g_decay * g + (1 - g_decay) * g0 + g_alpha * abs(w * err2 * x)
    w = w + alpha0 * g * err2 * x
    g_log[t] = g

# -----------------------------
# 2) Noisy gain-modulated learner, averaged over many runs (panel A)
# -----------------------------
dt = 0.1
# round() makes the step counts independent of how a / dt happens to round
# in floating point (plain int() truncates, e.g. 439.999... -> 439).
n_steps2 = int(round(100 / dt))
t2 = np.arange(0, n_steps2) * dt

# Target steps from 0.5 to 1.0 at t = 44
T2_series = np.ones(n_steps2) * 0.5
T2_series[int(round(44 / dt)):] = 1.0

num_runs = 1000
alpha2, g0_2, g_decay2, g_alpha2 = 0.1, 1.0, 0.5, 3.0
noise_std_g, noise_std_w = 0.025, 0.025
# Seeded generator so the figure is reproducible from run to run.
SEED2 = 0
rng2 = np.random.default_rng(SEED2)

# Per-run, per-step traces (rows = runs, columns = time steps)
G   = np.zeros((num_runs, n_steps2))   # noisy gain
W0  = np.zeros((num_runs, n_steps2))   # baseline-gain part: g0 * w
Wg0 = np.zeros((num_runs, n_steps2))   # gain-deviation part: (g - g0) * w
Wm  = np.zeros((num_runs, n_steps2))   # effective weight: g * w

# All runs are independent, so they are simulated together as vectors;
# only the time loop remains in Python.
g2 = np.full(num_runs, g0_2)
w2 = np.full(num_runs, 0.5)
for t in range(n_steps2):
    g_noisy = g2 + noise_std_g * rng2.standard_normal(num_runs)
    w_noisy = w2 + noise_std_w * rng2.standard_normal(num_runs)
    err = T2_series[t] - (g_noisy * w_noisy * x)
    G[:, t]   = g_noisy
    W0[:, t]  = g0_2 * w_noisy
    Wg0[:, t] = (g_noisy - g0_2) * w_noisy
    Wm[:, t]  = g_noisy * w_noisy
    # gain update uses the weight from *before* this step's weight update
    g2 = g_decay2 * g2 + (1 - g_decay2) * g0_2 + g_alpha2 * np.abs(w2 * err * x) * dt
    w2 = w2 + dt * alpha2 * g_noisy * err * x

# Statistics across runs, per time step
G_mean, G_std       = G.mean(axis=0),   G.std(axis=0)
W0_mean, W0_std     = W0.mean(axis=0),  W0.std(axis=0)
Wg0_mean, Wg0_std   = Wg0.mean(axis=0), Wg0.std(axis=0)
Wm_mean, Wm_std     = Wm.mean(axis=0),  Wm.std(axis=0)

# -----------------------------
# Reference step for panel B: where the gain-modulated learner's gain peaks
# -----------------------------
switch_idxs = np.where(np.diff(T1) != 0)[0]  # last index before each target change
t_first_pre  = int(switch_idxs[0])           # last step before the first switch
t_first_post = t_first_pre + 1               # first step after the first switch

# Window searched for the gain peak (also shaded in panel B).
# NOTE(review): T1 is -1 only on steps 144..153, so this window includes
# step 154, where the target is already back at +1 — confirm intended.
win = slice(144, 155)
g_max = float(np.max(g_log[win]))
# Offset by win.start so the window bounds are written in one place only.
t_gmax = win.start + int(np.argmax(g_log[win]))

# Losses of the gain-modulated learner are divided by this factor to draw the
# "flattened" landscape. The max with 1.0 ensures they are never scaled up.
FLATTEN_POWER = 2
scale_flat = max(g_max, 1.0) ** FLATTEN_POWER

def L_of(weff, T):
    """Return the squared-error loss 0.5 * (T - weff)**2.

    Works element-wise on NumPy arrays as well as on scalars.
    """
    residual = T - weff
    return 0.5 * residual ** 2

def next_states_slow_from_logs(w0_now, T_now):
    """Replay one delta-rule step of the slow-only learner.

    Mirrors the "Slow" update of the section-1 loop, using the module-level
    ``x`` and ``alpha_slow``.

    Returns a tuple (output before the update, output after the update).
    """
    y_before = w0_now * x
    w0_after = w0_now + alpha_slow * (T_now - y_before) * x
    return y_before, w0_after * x

def next_states_fs_from_logs(wf_now, ws_now, T_now):
    """Replay one step of the fast + slow learner from logged weights.

    Both weights share a single prediction error. The fast weight also decays
    by ``fast_decay`` each step. Uses the module-level ``x``, ``alpha_fast``,
    ``alpha_slow`` and ``fast_decay``, exactly as the section-1 loop does.

    Returns a tuple (output before the update, output after the update).
    """
    w_sum = wf_now + ws_now
    drive = T_now - w_sum * x  # shared prediction error
    wf_after = fast_decay * wf_now + alpha_fast * drive * x
    ws_after = ws_now + alpha_slow * drive * x
    return w_sum * x, (wf_after + ws_after) * x

def next_states_gm_from_logs(w_now, g_now, T_now):
    """Replay one step of the gain-modulated learner from logged (w, g).

    The step order matches the "Gain-modulated" update in the section-1 loop:
    the gain is updated first from the *current* weight, and the weight step
    then uses the *new* gain. With any other order, the arrow drawn in panel B
    would not land on the state the simulation actually reaches.

    Uses the module-level ``x``, ``alpha0``, ``g0``, ``g_decay`` and
    ``g_alpha``.

    Returns a tuple (output g_now * w_now * x, output one step later).
    """
    y_now = g_now * w_now * x
    err2 = T_now - y_now
    g_next = g_decay * g_now + (1 - g_decay) * g0 + g_alpha * abs(w_now * err2 * x)
    w_next = w_now + alpha0 * g_next * err2 * x
    return y_now, g_next * w_next * x

# Evaluate every learner at the gain-peak step. For each one, compute the
# effective weight and loss now (the dot in panel B) and one update later
# (the arrow tip).
t_ref = t_gmax
T_cur = float(T1[t_ref])

Weff_s_now, Weff_s_next = next_states_slow_from_logs(w0_log[t_ref], T_cur)
Weff_f_now, Weff_f_next = next_states_fs_from_logs(wf_log[t_ref], ws_log[t_ref], T_cur)
Weff_g_now, Weff_g_next = next_states_gm_from_logs(wg_log[t_ref], gg_log[t_ref], T_cur)

# The slow and fast-slow learners sit on the true loss ...
L_s_now_real, L_s_next_real = (L_of(v, T_cur) for v in (Weff_s_now, Weff_s_next))
L_f_now_real, L_f_next_real = (L_of(v, T_cur) for v in (Weff_f_now, Weff_f_next))
# ... while the gain-modulated learner sits on the loss divided by scale_flat.
L_g_now_flat, L_g_next_flat = (L_of(v, T_cur) / scale_flat for v in (Weff_g_now, Weff_g_next))

# -----------------------------
# Plotting
# -----------------------------
# 2 x 3 grid. Column 0 holds the two rows of panel A; columns 1 and 2 each
# hold one half of panel B, spanning both rows.
fig = plt.figure(figsize=(20, 5), facecolor='w')
gs = GridSpec(
    2, 3,
    figure=fig,
    height_ratios=[1, 1],
    width_ratios=[1, 1, 1.5],
    hspace=0.2,
    wspace=0.3,
)

# A-t: gain g across runs (mean +/- 1 s.d.) against its baseline g0 (dashed)
ax1 = fig.add_subplot(gs[0, 0])
ax1.fill_between(t2, G_mean - G_std, G_mean + G_std, color='limegreen', alpha=0.25)
ax1.axhline(g0_2, linestyle='--', color='slategray', linewidth=4)
ax1.plot(t2, G_mean, color='limegreen', linewidth=3)
ax1.set_xticks([])
ax1.set_yticks([])

# Inline text labels act as the key for this panel
ax1.text(t2[0] - 1, 1.125, r"$g$", va='bottom', ha='left', fontsize=26, color='limegreen')
ax1.text(t2[0] - 1, 1.075, r"$g_0$", va='bottom', ha='left', fontsize=26, color='slategray')

# Only the left spine is drawn; the bottom spine is hidden, so there is no
# point in styling its width.
for side in ('top', 'bottom', 'right'):
    ax1.spines[side].set_visible(False)
ax1.spines['left'].set_linewidth(5)

ax1.set_title('A)', loc='left', fontsize=32, fontweight='bold', x=-0.2, y=1.25)
ax1.set_ylabel('Gain', fontsize=28, labelpad=10)
# Intentionally no ax1.legend(): none of the artists above carries a label,
# so a legend call would only add an empty legend and a warning.

# B-r: outputs of the three learners around the target switch
col_T1 = 'salmon'  # drawn at y = -1, which the y tick labels call T_2
col_T2 = 'teal'    # drawn at y = +1, which the y tick labels call T_1

ax2 = fig.add_subplot(gs[:, 1])

learner_styles = [
    (y0_log,    'Slow weight',            'indigo'),
    (y_log,     'Fast-slow weights',      'royalblue'),
    (y_mod_log, 'Gain-modulated weights', 'darkorange'),
]
l_slow, l_fs, l_gain = (
    ax2.plot(series, linewidth=5, label=lab, color=col)[0]
    for series, lab, col in learner_styles
)

# Dotted reference lines at both targets
for level, colour in ((-1.0, col_T1), (+1.0, col_T2)):
    ax2.axhline(level, color=colour, linestyle=':', linewidth=10, alpha=1)

# Shaded switch window and the reference step used in panel B-l
ax2.axvspan(144, 155, color='lightgray', alpha=0.25)
ax2.axvline(t_gmax, color='black', linewidth=2.5, linestyle='--')

ax2.set_ylim(-1.5, 1.5)
# set_xticks widens the view to show every tick, so the x-limits are set
# further down, after the ticks.
ax2.set_xticks(np.arange(110, 201, 10))
ax2.set_yticks([-1.0, 0, 1.0])
ax2.set_yticklabels([r"$T_{2}$", "0", r"$T_{1}$"])

for side in ('top', 'right'):
    ax2.spines[side].set_visible(False)
for side in ('bottom', 'left'):
    ax2.spines[side].set_linewidth(5)

ax2.xaxis.set_tick_params(width=2.5, length=6, direction='out', labelsize=20)
ax2.tick_params(axis='y', direction='out', labelsize=30, width=5, length=10)

ax2.set_xlabel('Timesteps', fontsize=30)
ax2.set_ylabel('Output', fontsize=30)
ax2.set_xlim(135, 165)
ax2.set_title('B)', loc='left', fontsize=32, fontweight='bold', x=-0.35, y=1.1125)

# B-l: loss landscapes (solid = true loss, dashed = loss / scale_flat).
# Each learner's state at t_ref is shown as a dot, and its next state as an
# arrow tip.
axC = fig.add_subplot(gs[:, 2])
Weff_grid = np.linspace(-2.2, 2.2, 600)

# L_T2 has its minimum at W = +1, L_T1 at W = -1
L_T2 = L_of(Weff_grid, 1.0)
L_T1 = L_of(Weff_grid, -1.0)
L_T2_flat = L_T2 / scale_flat
L_T1_flat = L_T1 / scale_flat

for curve, colour, style, lab in (
    (L_T2,      col_T2, '-',  r"$L$ (T2)"),
    (L_T1,      col_T1, '-',  r"$L$ (T1)"),
    (L_T2_flat, col_T2, '--', r"$\tilde L$ (T2)"),
    (L_T1_flat, col_T1, '--', r"$\tilde L$ (T1)"),
):
    axC.plot(Weff_grid, curve, linewidth=3, color=colour, linestyle=style, label=lab)

c_slow, c_fs, c_gm = 'indigo', 'royalblue', 'darkorange'

# (W now, loss now, W next, loss next, colour) for each learner
learner_states = (
    (Weff_s_now, L_s_now_real, Weff_s_next, L_s_next_real, c_slow),
    (Weff_f_now, L_f_now_real, Weff_f_next, L_f_next_real, c_fs),
    (Weff_g_now, L_g_now_flat, Weff_g_next, L_g_next_flat, c_gm),
)

for w_a, l_a, _, _, colour in learner_states:
    axC.plot([w_a], [l_a], marker='o', ms=20, color=colour)


def arrow(ax, x0, y0, x1, y1, color):
    """Draw a plain arrow on *ax* from (x0, y0) to (x1, y1)."""
    ax.annotate("", xy=(x1, y1), xytext=(x0, y0),
                arrowprops=dict(arrowstyle="->", lw=5, color=color, shrinkA=0, shrinkB=0))


for w_a, l_a, w_b, l_b, colour in learner_states:
    arrow(axC, w_a, l_a, w_b, l_b, colour)

# Curve labels (the colour matches the curve each label sits on)
for wx, ly, tex, colour in (
    (-1.0,  0.07, r"$L_{T_{2}}$",         col_T1),
    (-1.25, 1.0,  r"$\tilde{L}_{T_{1}}$", col_T2),
    (+1.25, 1.0,  r"$\tilde{L}_{T_{2}}$", col_T1),
    (+1.0,  0.07, r"$L_{T_{1}}$",         col_T2),
):
    axC.text(wx, ly, tex, ha='center', va='bottom', color=colour, fontsize=30)

axC.set_xticks([-1.0, 1.0])
axC.set_xticklabels([r"$W^*_{T_{2}}$", r"$W^*_{T_{1}}$"], fontsize=20)
axC.tick_params(axis='x', width=5, length=10, direction='out')
axC.tick_params(axis='y', labelbottom=False, labelleft=False, width=0, length=0, direction='out')

# Limits are set after the ticks because set_xticks may widen the view
axC.set_xlim(-2.2, 2.2)
axC.set_ylim(-0.05, 2)
axC.set_xlabel("Effective weight", fontsize=30)
axC.set_ylabel("Loss", fontsize=30, labelpad=10)
for side in ('top', 'right'):
    axC.spines[side].set_visible(False)
for side in ('bottom', 'left'):
    axC.spines[side].set_linewidth(5)

# lgd: one shared legend above panel B. These labels replace the ones
# attached to the line artists in panel B-r.
legend_pairs = [
    (l_slow, 'Slow weight'),
    (l_fs, 'Fast-Slow weights'),
    (l_gain, 'Gain weight'),
]
handles = [handle for handle, _ in legend_pairs]
labels = [text for _, text in legend_pairs]
fig.legend(
    handles, labels,
    loc='upper center', bbox_to_anchor=(0.635, 1.02),
    ncol=3, fontsize=22, frameon=False,
)


# === A-b: per-step change (np.diff over time) of the run-averaged weight terms
ax3 = fig.add_subplot(gs[1, 0])
t_mid = t2[:-1]

# NOTE(review): the band half-width is np.diff(std), i.e. the change in the
# across-run s.d. It is not the s.d. of the per-step change itself
# (np.diff(W, axis=1).std(axis=0)), and it can be negative. Confirm this is
# the intended measure of spread.
delta_series = [
    (np.diff(W0_mean),  np.diff(W0_std),  'brown',      r'$w_{\rm slow}$'),
    (np.diff(Wg0_mean), np.diff(Wg0_std), 'red',        r'$w_{\rm fast}$'),
    (np.diff(Wm_mean),  np.diff(Wm_std),  'darkorange', r'$W_{\rm eff}$'),
]
for d_mean, d_std, colour, _ in delta_series:
    ax3.fill_between(t_mid, d_mean - d_std, d_mean + d_std, color=colour, alpha=0.2)
for d_mean, _, colour, lab in delta_series:
    ax3.plot(t_mid, d_mean, color=colour, linewidth=3, label=lab)

ax3.set_xticks([])
ax3.set_yticks([])
ax3.set_xlim(42, 46)
for side in ('top', 'right'):
    ax3.spines[side].set_visible(False)
for side in ('bottom', 'left'):
    ax3.spines[side].set_linewidth(5)
ax3.set_xlabel('Time', fontsize=30)
ax3.set_ylabel(r'$\Delta$Weight', fontsize=28, labelpad=10)

# Inline labels stacked near the top-left corner of the current view
x_left = ax3.get_xlim()[0] + 0.2
y_top = ax3.get_ylim()[1]
for frac, label_text, colour in (
    (0.725, r"$\Delta w_{\rm fast}$", 'red'),
    (0.95,  r"$\Delta W_{\rm eff}$",  'darkorange'),
    (0.5,   r"$\Delta w_{\rm slow}$", 'brown'),
):
    ax3.text(x_left, y_top * frac, label_text, color=colour, fontsize=26, ha='left', va='top')

fig.align_xlabels([axC, ax2, ax3])

# Save: set SAVE_FIGURE = True to write SVG/PNG/PDF copies of the figure.
# The files are written before the (possibly blocking) plt.show() call.
filename = 'fig_multitimescaleflattened'
SAVE_FIGURE = False
if SAVE_FIGURE:
    fig.savefig(f'{filename}.svg', format='svg', dpi=600, bbox_inches='tight')
    fig.savefig(f'{filename}.png', dpi=600, bbox_inches='tight')
    fig.savefig(f'{filename}.pdf', format='pdf', dpi=600, bbox_inches='tight')

plt.show()

