#
# SpatioTemporal Local Neural Propagation - Version 1.0
# General recurrent neural network given by a digraph
# Spatiotemporal Local Propagation (STLP)
# Hamiltonian Sign Flip Policy  (HSF)
# 25th of September, 2023
# Alessandro Betti and Marco Gori
# submission to ICLR-2024
#
import numpy as np
import matplotlib.pyplot as plt
import math
import random
#
def Target(
        T,      # horizon (sec)
        m,      # number of samples per period
        c,      # number of cycles (periods)
        sel,    # type of signal, see the docstring for the supported names
        io      # io = 0 selects the output (target) frequency, io = 1 the input frequency
):  # set the target
    """Return a sampled reference signal of length ``c * m``.

    The horizon ``T`` is split into ``c * m`` samples spaced ``tau = T / (c * m)``.
    The base angular frequency is ``omega = 2*pi*f`` with ``f = 0.001`` when
    ``io == 0`` and ``f = 0.01`` otherwise.

    Supported values of ``sel`` (``k`` is the global sample index):
      'constant'       -> -0.5 everywhere
      'zero'           -> 0 everywhere
      'sinusoindal',
      'g-sinusoindal'  -> sin(omega * tau * k)
      'grad-sim'       -> (1 - exp(-tau * k / 10000)) * sin(omega * tau * k)
      'cosinusoidal'   -> cos(omega * tau * k)
      'cos-constant'   -> cosine over the first half, then its last value held
      'square'         -> per period: -1 up to the midpoint, +1 after it
      'sin-constant'   -> per period, in thirds: cosine, constant -1.5, cosine
                          at 0.3 * omega; identical pattern in every period
      'hybrid'         -> like 'sin-constant' with constant -1.0, but the
                          cosine phase keeps running across periods

    Returns:
        numpy.ndarray of shape ``(c * m,)``.

    Raises:
        ValueError: if ``sel`` is not one of the names above.
    """
    f = 0.001 if io == 0 else 0.01
    n_samples = c * m
    tau = T / n_samples        # quantization time for the ODE (sec)
    omega = 2 * np.pi * f
    amp = 1                    # amplitude of the periodic signals
    smooth_tau = 10000         # time constant of the 'grad-sim' amplitude ramp
    constant_value = -0.5      # level of the 'constant' signal
    k = np.arange(n_samples)   # global sample index

    if sel == 'constant':
        return np.full(n_samples, constant_value)
    if sel == 'zero':
        return np.zeros(n_samples)
    if sel in ('sinusoindal', 'g-sinusoindal'):
        return amp * np.sin(omega * tau * k)
    if sel == 'grad-sim':
        # Sinusoid whose amplitude ramps up smoothly from 0 towards amp.
        return amp * (1 - np.exp(-tau * k / smooth_tau)) * np.sin(omega * tau * k)
    if sel == 'cosinusoidal':
        return amp * np.cos(omega * tau * k)
    if sel == 'cos-constant':
        # Cosine over the first half, then hold the last cosine sample.
        # (Copying z[half] before it is written would yield zeros.)
        half = n_samples // 2
        z = np.zeros(n_samples)
        z[:half] = amp * np.cos(omega * tau * k[:half])
        if half > 0:
            z[half:] = z[half - 1]
        return z
    if sel == 'square':
        # np.round rounds halves to even, like the built-in round(), so the
        # midpoint sample of each period is -1.
        local = np.arange(m)
        return np.tile(2 * np.round(local / m) - 1, c)
    if sel in ('sin-constant', 'hybrid'):
        third = m // 3
        local = k % m              # position inside the current period
        if sel == 'sin-constant':
            # Phase restarts at m in every period -> identical periods.
            t_idx, level = m + local, -1.5
        else:
            # Phase follows the global index across periods.
            t_idx, level = k, -1.0
        # Each segment covers its full third, including the last sample of
        # the period (which 'sin-constant' previously left at 0).
        return np.where(
            local < third,
            amp * np.cos(omega * tau * t_idx),
            np.where(local < 2 * third, level, amp * np.cos(0.3 * omega * tau * t_idx)),
        )
    raise ValueError(f"Target: unknown signal type {sel!r}")
#
#  Neuronal function models
#

def Sigma(u):
    """Neuron activation: elementwise hyperbolic tangent of ``u``."""
    activation = np.tanh(u)
    return activation

def DSigma(u):
    """Derivative of ``Sigma`` at ``u``, written as (1 - s) * (1 + s) with s = Sigma(u)."""
    s = Sigma(u)
    return (1 - s) * (1 + s)

def SigmaT(u):
    """Elementwise hyperbolic tangent of ``u`` (numerically the same as ``Sigma``)."""
    result = np.tanh(u)
    return result

def DSigmaT(u):
    """Derivative of ``SigmaT`` at ``u``: (1 - s) * (1 + s) with s = SigmaT(u).

    Evaluates ``SigmaT`` (not ``Sigma``) so this stays the derivative of
    ``SigmaT`` even if the two activations are ever made different.
    """
    s = SigmaT(u)
    return (1 - s) * (1 + s)

def Neural_Propagation(T,               # time horizon
                       n,               # number of quantization points
                       m,               # number of neurons
                       lq_args,         # lq arguments: w_range,b_range,x_range,q,r_w,r_0w,r_0x
                       u,               # input
                       z                # target vector
                       ):
    """Integrate the STLP state/costate/weight equations with forward Euler.

    The network has ``m`` neurons with state ``x[t, i]``, recurrent weights
    ``w[t, i, j]`` and input weights ``b[t, i]``. Neuron 0 is the output and
    is driven towards ``z`` through the fitting term. Costates ``p_x``,
    ``p_w``, ``p_b`` are propagated alongside. Every equation is multiplied
    by ``SignFlip[t]``, which is negated every ``flip_period`` steps
    (Hamiltonian Sign Flip policy).

    Args:
        T: time horizon; the Euler step is ``tau = T / n``.
        n: number of time steps (length of every returned time axis).
        m: number of neurons.
        lq_args: ``[w_range, b_range, x_range, q, r_w, r_0w, r_0x]``.
            ``q`` weights the fitting term, ``r_w`` the velocity term
            ``|w_dot|^2 + |b_dot|^2``, ``r_0w`` the magnitude term
            ``|w|^2 + |b|^2`` and ``r_0x`` the state term.
        u: input signal, indexed by ``t`` for ``t`` in ``range(n - 1)``.
        z: target for neuron 0, indexed by ``t`` for ``t`` in ``range(n - 1)``.

    Returns:
        ``(SignFlip, FlipCounter, w_square, x, p_x, L, H, w, b)`` with shapes
        ``(n,), (n,), (n,), (n, m), (n, m), (n,), (n,), (n, m, m), (n, m)``.
        The loop never processes step ``n - 1``, so ``w_square``, ``L`` and
        ``H`` stay 0 there.

    Initial weights, biases and states are drawn from NumPy's global RNG.
    Progress is printed every 5500 steps.
    """
    ##########################################################################
    # VARIABLE ALLOCATION
    ##########################################################################
    sx_dim = (n, m)
    sw_dim = (n, m, m)
    sb_dim = (n, m)
    # Pre-activation of each neuron at the current step (one entry per neuron)
    f = np.zeros(m)
    # Costates
    p_x = np.zeros(sx_dim)
    p_w = np.zeros(sw_dim)
    p_b = np.zeros(sb_dim)
    # Time derivatives
    x_dot = np.zeros(sx_dim)
    w_dot = np.zeros(sw_dim)
    b_dot = np.zeros(sb_dim)
    p_x_dot = np.zeros(sx_dim)
    p_w_dot = np.zeros(sw_dim)
    p_b_dot = np.zeros(sb_dim)
    # Per-step regularization terms
    w_dot_square = np.zeros(n)      # |w_dot|^2 + |b_dot|^2
    w_square = np.zeros(n)          # |w|^2 + |b|^2
    overall_x_square = np.zeros(n)  # |x|^2 + w_square
    # Sign flip bookkeeping
    SignFlip = np.zeros(n)
    FlipCounter = np.zeros(n)
    # Lagrangian and Hamiltonian
    L = np.zeros(n)
    H = np.zeros(n)
    ##########################################################################
    # INITIALIZATION
    ##########################################################################
    tau = T / n        # forward-Euler step
    w_range = lq_args[0]
    b_range = lq_args[1]
    x_range = lq_args[2]
    q = lq_args[3]     # fitting weight
    r_w = lq_args[4]   # velocity regularization
    r_0w = lq_args[5]  # weight-magnitude regularization
    r_0x = lq_args[6]  # state regularization
    # Random initial values. Only the t = 0 slices matter; every later slice
    # is overwritten by the Euler updates. The draw order (w, b, x) is kept.
    # NOTE(review): 2*w_range*rand - 1 is uniform on [-1, 2*w_range - 1],
    # not symmetric around 0; confirm whether w_range*(2*rand - 1) was meant.
    w = 2*w_range*np.random.rand(n,m,m)-1
    b = 2*b_range*np.random.rand(n,m)-1
    alpha = 1
    x = x_range*np.random.rand(n,m)
    # Main (Poincare-Bendixson initialization) of the sign-flip policy
    SignFlip[0] = +1
    FlipCounter[0] = 0
    flip_period = 1
    ##########################################################################
    # Processing loop
    ##########################################################################
    for t in range(n-1):
        # Pre-activations and magnitude terms at step t
        w_square[t] = 0
        overall_x_square[t] = 0
        for i in range(m):
            f[i] = 0
            for j in range(m):
                f[i] += w[t,i,j]*x[t,j]
                w_square[t] += w[t,i,j]**2
            f[i] += b[t,i] * u[t]
            w_square[t] += b[t,i]**2
            overall_x_square[t] += x[t,i]**2
        # NOTE(review): this makes (w, b) penalized by r_0x + r_0w in L; confirm intended.
        overall_x_square[t] += w_square[t]
        #
        # Fitting force on the output neuron's costate
        p_x_dot[t,0] = q * (z[t] - x[t,0])
        #
        # Hamilton's Spatiotemporal Local Propagation - STLP
        for i in range(m):
            x_dot[t,i] = SignFlip[t]*alpha*(-x[t,i]+Sigma(f[i]))                                   # (ST)LP1
            x[t+1,i] = x[t,i] + tau*x_dot[t,i]
            p_b_dot[t,i] = - SignFlip[t]*(alpha*DSigma(f[i])*p_x[t,i]*u[t] + r_0w*b[t,i])
            p_b[t+1,i] = p_b[t,i] + tau*p_b_dot[t,i]
            p_x_dot[t,i] += alpha*p_x[t,i] - r_0x * x[t,i]
            b_dot[t,i] = -SignFlip[t]*p_b[t,i]/r_w
            b[t+1,i] = b[t,i] + tau*b_dot[t,i]
            for j in range(m):
                p_x_dot[t,i] += - alpha*DSigma(f[j])*p_x[t,j]*w[t,j,i]                              # (ST)LP2
                p_w_dot[t,i,j] = - SignFlip[t]*(alpha*DSigma(f[i])*p_x[t,i]*x[t,j] + r_0w*w[t,i,j])  # (ST)LP3
                p_w[t+1,i,j] = p_w[t,i,j] + tau*p_w_dot[t,i,j]
                w_dot[t,i,j] = -SignFlip[t]*p_w[t,i,j]/r_w                                          # Control
                w[t+1,i,j] = w[t,i,j] + tau*w_dot[t,i,j]
            p_x_dot[t,i] = SignFlip[t]*p_x_dot[t,i]
            p_x[t+1,i] = p_x[t,i] + tau*p_x_dot[t,i]
        #
        # Velocity (kinetic) term. Must be accumulated after the loop above:
        # w_dot[t] and b_dot[t] are only assigned there, so summing them any
        # earlier in the step always yields 0.
        w_dot_square[t] = 0
        for i in range(m):
            for j in range(m):
                w_dot_square[t] += w_dot[t,i,j]**2
            w_dot_square[t] += b_dot[t,i]**2
        #
        # Sign flip strategy to move into the Hamiltonian track
        if t % flip_period == 0:
            SignFlip[t+1] = (-1) * SignFlip[t]
            FlipCounter[t+1] = FlipCounter[t] + 1
        else:
            SignFlip[t+1] = SignFlip[t]
            FlipCounter[t+1] = FlipCounter[t]
        #
        # Lagrangian and Hamiltonian at step t
        tge = z[t] - x[t,0]   # tracking error of the output neuron
        fitting = 0.5*q*tge**2
        regularization = 0.5*(r_0x*overall_x_square[t] + r_0w*w_square[t] + r_w*w_dot_square[t])
        L[t] = fitting + regularization
        px_terms = 0
        pw_terms = 0
        pb_terms = 0
        for i in range(m):
            px_terms += p_x[t,i]*x_dot[t,i]
            pb_terms += p_b[t,i]*b_dot[t,i]
            for j in range(m):
                pw_terms += p_w[t,i,j]*w_dot[t,i,j]
        H[t] = L[t] + px_terms + pw_terms + pb_terms
        if (t%5500==0): print(t/n,z[t],x[t,0],tge,L[t],H[t],SignFlip[t])
    return SignFlip,FlipCounter,w_square, x, p_x, L, H, w, b

  

##########################################################################
# Main program
##########################################################################
# Time discretization: a `horizon`-second run sampled with `point_T` points
# per period over `period` periods.
horizon, point_T, period = 21600, 43200, 10
frequency = 50                 # NOTE(review): not referenced anywhere below
n_points = point_T * period    # total number of samples
#
# Network size and initialization ranges (weights, biases, states)
#
n_neurons = 5
w_init, b_init, x_init = 5, 5, 1
#
# Fitting weight and regularization coefficients
#
accuracy = 100
reg_w, reg_0w, reg_0x = 0.1, 1, 1
lq = [w_init, b_init, x_init, accuracy, reg_w, reg_0w, reg_0x]
#
# Signal selection. Names handled by Target: 'zero', 'constant',
# 'sinusoindal', 'g-sinusoindal', 'cosinusoidal', 'square', 'cos-constant',
# 'grad-sim', 'sin-constant', 'hybrid'.
#
choice = 'zero'         # input signal
io_sel = 1              # 1 -> input frequency
in_signal = Target(horizon, point_T, period, choice, io_sel)

choice = 'sinusoindal'  # target signal
io_sel = 0              # 0 -> output (target) frequency
target = Target(horizon, point_T, period, choice, io_sel)

(SF, FC, WeightSq, State, Costate,
 Lagrangian, Hamiltonian, Weights, Biases) = Neural_Propagation(
    horizon, n_points, n_neurons, lq, in_signal, target
)

#
# save results onto files
#
def _save_column(path, values):
    """Write each item of ``values`` to ``path`` on its own line."""
    # `with` guarantees the file is closed even if a write fails.
    with open(path, "w", encoding="utf-8") as out:
        out.writelines(f"{item}\n" for item in values)


# NOTE(review): reg_0x is not recorded in parameters.txt.
with open("parameters.txt", "w", encoding="utf-8") as out:
    out.write(f" Experiment Main Parameters \n number of neurons: {n_neurons}\n accuracy = {accuracy}\n weight reg = {reg_0w}\n dot_weight reg = {reg_w}\n")

_save_column("sw.txt", WeightSq)
_save_column("in_signal.txt", in_signal)
_save_column("target.txt", target)
_save_column("State.txt", State[0:n_points, 0])      # output neuron state
_save_column("Costate.txt", Costate[0:n_points, 0])  # output neuron costate
_save_column("L.txt", Lagrangian)
_save_column("H.txt", Hamiltonian)
_save_column("SignFlip.txt", SF)
_save_column("FlipCounter.txt", FC)
#
# plot results
#
# Plot target, squared weight norm and the output neuron's state over time.
time = np.linspace(0, horizon, n_points)
plt.plot(time, target, 'r-', linewidth=2, label='z (target)')
plt.plot(time, WeightSq, 'k-', linewidth=2, label='square of the weights')
plt.plot(time, State[0:n_points, 0], 'b-', linewidth=2, label='x')
# Other traces available from the run:
# plt.plot(time, Costate[0:n_points, 0], 'g-', linewidth=2, label='p_x')
# plt.plot(time, Lagrangian, 'k-', linewidth=2, label='L')
# plt.plot(time, Hamiltonian, 'c-', linewidth=2, label='H')
# plt.xlabel('time (sec)')
# Keep the per-line labels above (passing a list of strings to legend()
# would override them). Put the run configuration in the title so it always
# matches the parameters actually used.
plt.legend(title=f"accuracy={accuracy}, reg_w={reg_w}, reg_0w={reg_0w}, "
                 f"point_T={point_T}, per={period}")
plt.show()

