from brian2 import *
import numpy as np
import matplotlib.pyplot as plt

# Define the Lorenz96 function
def lorenz96(x, t, F=8):
    """Return the time derivative of the Lorenz-96 system.

    dx_i/dt = (x_{i+1} - x_{i-2}) * x_{i-1} - x_i + F, with indices taken
    modulo N (periodic ring of N variables).

    Parameters
    ----------
    x : array_like, shape (..., N)
        Current state. Leading dimensions, if any, are treated as a batch of
        independent states; the ring runs along the last axis.
    t : float
        Current time. Unused (the system is autonomous); kept so the
        signature matches ``f(x, t)``-style integrators.
    F : float, optional
        Constant forcing term (default 8, the value previously hard-coded).

    Returns
    -------
    numpy.ndarray
        dx/dt as a float array with the same shape as ``x``.
    """
    x = np.asarray(x, dtype=float)
    # np.roll(x, k)[..., i] == x[..., (i - k) % N], so the three rolls below
    # are x_{i+1}, x_{i-2} and x_{i-1}. Vectorised: no per-element Python loop,
    # and N == 1 works (the old x[i-2] indexing raised IndexError).
    x_next = np.roll(x, -1, axis=-1)
    x_prev2 = np.roll(x, 2, axis=-1)
    x_prev = np.roll(x, 1, axis=-1)
    return (x_next - x_prev2) * x_prev - x + F

# Set the random seed
# Seed NumPy (Lorenz-96 data, parameter draws) and Brian2's own RNG
# (PoissonGroup spikes, rand()/randn() inside Brian2 code strings).
np.random.seed(12345)
seed(12345)

# ---------------------------------------------------------------------------
# Simulation parameters
# ---------------------------------------------------------------------------
N = 1000          # number of Lorenz-96 variables == number of input neurons
T = 5000          # total simulated time (ms)
dt = 0.1          # signal sampling step (ms); also used as Brian2's clock step
t_train = 4000    # duration of the training segment (ms)
alpha = 0.001     # STDP learning rate: scales the random STDP amplitudes
# The parameters below are new in this version and are untuned starting points.
n_out = 1         # number of output neurons
bin_ms = 10       # width of the spike-count bins used as readout features (ms)
ridge = 1e-3      # L2 penalty of the linear readout
rate_0 = 10.0     # mean input firing rate (Hz)
rate_gain = 5.0   # input-rate change per standard deviation of the signal (Hz)
n_substeps = 2    # RK4 substeps per signal sample (step 0.05 keeps RK4 stable)

defaultclock.dt = dt*ms

n_steps = int(round(T / dt))
n_train = int(round(t_train / dt))
n_bins = int(round(T / bin_ms))
n_train_bins = int(round(t_train / bin_ms))
samples_per_bin = int(round(bin_ms / dt))
if n_bins * samples_per_bin != n_steps or n_train_bins * samples_per_bin != n_train:
    raise ValueError('T and t_train must be whole multiples of bin_ms')


# ---------------------------------------------------------------------------
# Input signal: one continuous Lorenz-96 trajectory (training then testing)
# ---------------------------------------------------------------------------
def _rk4_step(x, t, h):
    """Advance the Lorenz-96 state ``x`` by one classical RK4 step of size ``h``."""
    k1 = lorenz96(x, t)
    k2 = lorenz96(x + 0.5*h*k1, t + 0.5*h)
    k3 = lorenz96(x + 0.5*h*k2, t + 0.5*h)
    k4 = lorenz96(x + h*k3, t + h)
    return x + (h / 6.0) * (k1 + 2*k2 + 2*k3 + k4)


def _lorenz96_trajectory(n_samples, n_vars, step, n_noisy, substeps, n_spinup=200):
    """Return an (n_samples, n_vars) Lorenz-96 trajectory sampled every ``step``.

    Additive noise 0.5*sqrt(step)*randn (as in the original training data) is
    applied after each of the first ``n_noisy`` samples only.
    """
    h = step / substeps
    # x_i = 8 is a fixed point for lorenz96's default forcing. A perfectly
    # symmetric start (e.g. all zeros, as the old test data used) never leaves
    # the invariant manifold x_0 = ... = x_{N-1} and is not chaotic, so perturb
    # it and discard a spin-up transient.
    state = 8.0 + 0.01 * np.random.randn(n_vars)
    for _ in range(n_spinup * substeps):
        state = _rk4_step(state, 0.0, h)
    traj = np.empty((n_samples, n_vars))
    for k in range(n_samples):
        traj[k] = state
        # Explicit Euler with step 0.1 (the old scheme) blows up for Lorenz-96
        # at F=8; RK4 with smaller substeps stays on the attractor.
        for s in range(substeps):
            state = _rk4_step(state, k*step + s*h, h)
        if k < n_noisy:
            state = state + 0.5 * np.sqrt(step) * np.random.randn(n_vars)
    return traj


x_all = _lorenz96_trajectory(n_steps, N, dt, n_noisy=n_train, substeps=n_substeps)
y_all = x_all[:, 0].copy()          # readout target: Lorenz-96 variable x_0

# Encode the signal as Poisson rates: input neuron j fires at
# rate_0 + rate_gain * z_j(t), where z is the state standardised with
# training-segment statistics. Done in place because the array is
# n_steps x N floats.
mu = x_all[:n_train].mean()
sigma = x_all[:n_train].std()
input_rates = x_all
del x_all
input_rates -= mu
input_rates *= rate_gain / sigma
input_rates += rate_0
np.maximum(input_rates, 0.0, out=input_rates)
stimulus = TimedArray(input_rates, dt=dt*ms)   # stimulus(t, i) -> rate in Hz (unitless)

# ---------------------------------------------------------------------------
# Network
# ---------------------------------------------------------------------------
v_rest = -65*mV
v_reset = -65*mV
v_th = -50*mV
delta_t = 2*mV
eqs_neurons = '''
dv/dt = (-(v - v_rest) + delta_t*exp((v - v_th) / delta_t)) / tau : volt (unless refractory)
tau : second (constant)
'''

# Input layer: Poisson neurons whose rates follow the Lorenz-96 signal (the
# original used a constant 10 Hz, so the signal never reached the network).
input_layer = PoissonGroup(N, rates='stimulus(t, i) * Hz')

output_layer = NeuronGroup(n_out, eqs_neurons, threshold='v>v_th', reset='v=v_reset',
                           refractory=2*ms, method='euler')
# Gamma-distributed membrane time constant, one per output neuron.
# NOTE(review): the original multiplied these draws by ms (mean 0.075 ms, below
# the 0.1 ms step); seconds (mean 75 ms) is assumed to be the intended unit.
output_layer.tau = np.random.gamma(2.5, 0.03, size=n_out) * second
# Initial condition, set once. The original used run_regularly, which
# re-randomised v every 10 ms throughout the run.
output_layer.v = 'v_rest + randn() * delta_t'

# Plastic synapses with pair-based STDP traces (Brian2 has no STDP class; the
# Brian 1 STDP API is replaced by event-driven Synapses equations). Pre-before-
# post pairs potentiate (apre > 0), post-before-pre pairs depress (apost < 0).
taupost = np.random.gamma(2.5, 0.03) * second
Apost = -alpha * np.random.gamma(1, 0.2)
syn = Synapses(input_layer, output_layer,
               model='''w : 1
                        taupre : second (constant)
                        Apre : 1 (constant)
                        dapre/dt = -apre / taupre : 1 (event-driven)
                        dapost/dt = -apost / taupost : 1 (event-driven)''',
               on_pre='''v_post += w * 0.5*mV
                         apre += Apre
                         w = clip(w + apost, 0, 1)''',
               on_post='''apost += Apost
                          w = clip(w + apre, 0, 1)''')
syn.connect()
syn.w = 'rand()'
syn.taupre = np.random.gamma(1.5, 0.001, size=len(syn)) * second
syn.Apre = alpha * np.random.gamma(1, 1, size=len(syn))

# Monitors
input_spikes = SpikeMonitor(input_layer)
output_spikes = SpikeMonitor(output_layer)

# Run the simulation
run(T*ms, report='text')


# ---------------------------------------------------------------------------
# Linear readout on binned spike counts
# ---------------------------------------------------------------------------
def _binned_counts(monitor, n_neurons, n_bins, bin_width_ms):
    """Return an (n_bins, n_neurons) array of spike counts in consecutive bins."""
    counts = np.zeros((n_bins, n_neurons))
    bin_idx = (np.asarray(monitor.t_) / (bin_width_ms * 1e-3)).astype(int)
    bin_idx = np.minimum(bin_idx, n_bins - 1)   # guard against rounding at t ~ T
    np.add.at(counts, (bin_idx, np.asarray(monitor.i)), 1.0)
    return counts


def _fit_ridge(X, y, lam):
    """Return w minimising |X w - y|^2 + lam |w|^2 (ridge regression)."""
    A = X.T @ X + lam * np.eye(X.shape[1])
    return np.linalg.solve(A, X.T @ y)


features = np.hstack([
    _binned_counts(input_spikes, N, n_bins, bin_ms),
    _binned_counts(output_spikes, n_out, n_bins, bin_ms),
    np.ones((n_bins, 1)),                       # bias term
])
# Per-bin mean of x_0. The same, unlagged target is used for training and
# testing; the original lagged only the training target by one sample.
y_binned = y_all.reshape(n_bins, samples_per_bin).mean(axis=1)

X_train, X_test = features[:n_train_bins], features[n_train_bins:]
y_train, y_test = y_binned[:n_train_bins], y_binned[n_train_bins:]

w_out = _fit_ridge(X_train, y_train, ridge)

# Compute the training and testing errors
y_train_pred = X_train @ w_out
train_error = np.sqrt(np.mean((y_train - y_train_pred)**2))
y_test_pred = X_test @ w_out
test_error = np.sqrt(np.mean((y_test - y_test_pred)**2))

# Plot the results (one point per bin, at the bin centre)
t_bins = (np.arange(n_bins) + 0.5) * bin_ms
plt.figure(figsize=(16,6))
plt.subplot(1,2,1)
plt.plot(t_bins[:n_train_bins], y_train, label='Target')
plt.plot(t_bins[:n_train_bins], y_train_pred, label='Prediction')
plt.xlabel('Time (ms)')
plt.ylabel('Signal')
plt.title('Training Results (RMSE={:.4f})'.format(train_error))
plt.legend()
plt.subplot(1,2,2)
plt.plot(t_bins[n_train_bins:], y_test, label='Target')
plt.plot(t_bins[n_train_bins:], y_test_pred, label='Prediction')
plt.xlabel('Time (ms)')
plt.ylabel('Signal')
plt.title('Testing Results (RMSE={:.4f})'.format(test_error))
plt.legend()
plt.show()
