import numpy as np
import matplotlib.pyplot as plt
from brian2 import *

# Define the Lorenz96 system
def lorenz96(X, t, F=8.17):
    """Return the time derivative of the Lorenz-96 system.

    Implements ``dX_i/dt = (X_{i+1} - X_{i-2}) * X_{i-1} - X_i + F`` with
    cyclic (periodic) indices, using the ``f(state, t)`` signature expected
    by ODE integrators.

    Args:
        X: 1-D sequence of state variables (any length N >= 0).
        t: Current time. Unused because the system is autonomous; kept so the
            function can be passed directly to an ``f(y, t)`` integrator.
        F: Constant forcing term. Defaults to 8.17, the value this function
            previously hard-coded.

    Returns:
        A float numpy array of length N holding dX/dt.
    """
    X = np.asarray(X, dtype=float)
    # np.roll(X, s)[i] == X[(i - s) % N], so every neighbour index wraps
    # modulo N -- including X_{i-2}, which plain negative indexing only
    # handled for N >= 2 (a length-1 state used to raise IndexError).
    x_next = np.roll(X, -1)   # X_{i+1}
    x_prev = np.roll(X, 1)    # X_{i-1}
    x_prev2 = np.roll(X, 2)   # X_{i-2}
    return (x_next - x_prev2) * x_prev - X + F

# ---------------------------------------------------------------------------
# Spiking-reservoir demo: a recurrent network of leaky integrate-and-fire
# neurons with pair-based STDP is driven by a Lorenz-96 trajectory (neuron i
# receives variable i). A linear readout, fitted by ridge regression on
# exponentially filtered spike trains, reconstructs the first Lorenz-96
# variable. Training and testing use disjoint time windows of one run.
# ---------------------------------------------------------------------------

# Set the parameters
N = 100  # number of neurons (also the dimension of the Lorenz-96 input)
tau_mean = 10*ms  # mean membrane time constant
tau_std = 2*ms  # standard deviation of membrane time constant
tau_min = 2*ms  # lower bound so no sampled time constant is tiny or negative
# tau_mean/tau_std already carry units: strip them before sampling and attach
# ms exactly once (multiplying a unitful sample by ms again was a units bug).
tau_values = np.clip(
    np.random.normal(loc=float(tau_mean/ms), scale=float(tau_std/ms), size=N),
    float(tau_min/ms), None)*ms  # membrane time constants
tau_p = 20*ms  # STDP potentiation time constant (presynaptic trace decay)
tau_n = 20*ms  # STDP depression time constant (postsynaptic trace decay)
A_pre = 0.01  # weight change added to the presynaptic trace per pre spike
A_post = -1.05*A_pre  # postsynaptic trace increment (slightly depression-dominated)
w_min = 0  # minimum synaptic weight
w_max = 1  # maximum synaptic weight
w_scale = 0.01  # jump in postsynaptic v per unit weight per spike (tunable)
alpha = 0.01  # ridge regularization coefficient for readout weights
T = 10000*ms  # simulation time
dt = 1*ms  # simulation time step (also the input/readout sampling step)
t_train = 8000*ms  # training time
t_test = T - t_train  # testing time (the remaining 2000 ms)
lorenz_h = 0.01  # Lorenz-96 time units advanced per simulation step
lorenz_burn_in = 1000  # RK4 steps discarded before sampling the trajectory
I_bias = 1.0  # input current at the mean Lorenz-96 value (== threshold)
I_gain = 0.5  # input current per standard deviation of the Lorenz-96 signal
tau_readout = 20*ms  # time constant of the exponential spike-train filter

n_steps = int(round(float(T/dt)))
n_train = int(round(float(t_train/dt)))
n_test = int(round(float(t_test/dt)))

# Run the network on the same grid the input and readout are sampled on.
defaultclock.dt = dt


def _integrate_lorenz96(x0, h, n_samples, n_burn_in=0):
    """Integrate ``lorenz96`` with fixed-step RK4 and return sampled states.

    The first ``n_burn_in`` steps are discarded. Row k of the returned
    (n_samples, len(x0)) array is the state after ``n_burn_in + k`` steps.
    """
    x = np.array(x0, dtype=float)
    states = np.empty((n_samples, x.size))
    for step in range(n_burn_in + n_samples):
        if step >= n_burn_in:
            states[step - n_burn_in] = x
        t_now = step*h
        k1 = lorenz96(x, t_now)
        k2 = lorenz96(x + 0.5*h*k1, t_now + 0.5*h)
        k3 = lorenz96(x + 0.5*h*k2, t_now + 0.5*h)
        k4 = lorenz96(x + h*k3, t_now + h)
        x = x + (h/6.0)*(k1 + 2*k2 + 2*k3 + k4)
    return states


def _filtered_spike_counts(spike_times, spike_indices, n_neurons, n_bins,
                           bin_width, tau_filter):
    """Bin spikes per neuron and low-pass them with a causal exponential kernel.

    ``spike_times``, ``bin_width`` and ``tau_filter`` are plain floats in the
    same time unit. Returns an (n_bins, n_neurons) array in which entry
    [k, j] sums exp(-(k - b) * bin_width / tau_filter) over the spikes of
    neuron j in bins b <= k.
    """
    counts = np.zeros((n_bins, n_neurons))
    bins = np.rint(np.asarray(spike_times, dtype=float)/bin_width).astype(int)
    cells = np.asarray(spike_indices, dtype=int)
    in_range = (bins >= 0) & (bins < n_bins)
    # np.add.at accumulates correctly even if a (bin, neuron) cell repeats.
    np.add.at(counts, (bins[in_range], cells[in_range]), 1.0)
    decay = np.exp(-bin_width/tau_filter)
    filtered = np.empty_like(counts)
    trace = np.zeros(n_neurons)
    for k, row in enumerate(counts):
        trace = decay*trace + row
        filtered[k] = trace
    return filtered


def _fit_ridge(features, targets, reg):
    """Closed-form ridge regression: argmin_W ||features @ W - targets||^2 + reg*||W||^2."""
    gram = features.T @ features + reg*np.eye(features.shape[1])
    return np.linalg.solve(gram, features.T @ targets)


# Compute the Lorenz-96 input and the readout target. Start from a generic,
# non-uniform state: a uniform state stays uniform under Lorenz-96 (every
# component gets the same derivative), so np.zeros(N) never becomes chaotic.
lorenz_states = _integrate_lorenz96(np.random.normal(size=N), lorenz_h,
                                    n_steps, n_burn_in=lorenz_burn_in)
lorenz_z = (lorenz_states - lorenz_states.mean())/lorenz_states.std()
# Row k drives the network during step k; column i feeds neuron i.
input_current = TimedArray(I_bias + I_gain*lorenz_z, dt=dt)
target = lorenz_states[:, 0]  # readout target: the first Lorenz-96 variable

# Define the spiking neuron model (I is read from the TimedArray every step)
eqs = '''
dv/dt = (-v + I)/tau : 1
I = input_current(t, i) : 1
tau : second
'''
neurons = NeuronGroup(N, eqs, threshold='v>1', reset='v=0', method='euler')
neurons.v = 0
neurons.tau = tau_values

# Define the recurrent pair-based STDP synapses (no self-connections)
stdp_model = '''
w : 1
dapre/dt = -apre/tau_p : 1 (event-driven)
dapost/dt = -apost/tau_n : 1 (event-driven)
'''
synapses = Synapses(neurons, neurons, model=stdp_model,
                    on_pre='''
                    v_post += w_scale*w
                    apre += A_pre
                    w = clip(w + apost, w_min, w_max)
                    ''',
                    on_post='''
                    apost += A_post
                    w = clip(w + apre, w_min, w_max)
                    ''')
synapses.connect(condition='i!=j')
synapses.w = 'w_min + rand()*(w_max - w_min)'

# One monitor records every neuron's spikes (indices and times).
spike_monitor = SpikeMonitor(neurons)

# Run the simulation
run(T)

# Readout features: filtered spike counts per step plus a constant bias column
rates = _filtered_spike_counts(spike_monitor.t/second, spike_monitor.i, N,
                               n_steps, float(dt/second),
                               float(tau_readout/second))
features = np.hstack([rates, np.ones((n_steps, 1))])
X_train = features[:n_train]
X_test = features[n_train:n_train + n_test]
y_train = target[:n_train]
y_test = target[n_train:n_train + n_test]

# Train the readout weights with closed-form ridge regression
W = _fit_ridge(X_train, y_train, alpha)

# Evaluate on the training window and the held-out test window
predicted_output_train = X_train @ W
predicted_output_test = X_test @ W
mse_train = np.mean((y_train - predicted_output_train)**2)
mse_test = np.mean((y_test - predicted_output_test)**2)
print('Train MSE:', mse_train)
print('MSE:', mse_test)

# Plot the results
t_axis = np.arange(n_steps)*float(dt/second)
plt.figure(figsize=(12, 6))
plt.subplot(2, 1, 1)
plt.plot(t_axis[:n_train], y_train, label='Target')
plt.plot(t_axis[:n_train], predicted_output_train, label='Output')
plt.legend()
plt.ylabel('Input/Output')
plt.title('Training')

plt.subplot(2, 1, 2)
plt.plot(t_axis[n_train:n_train + n_test], y_test, label='Target')
plt.plot(t_axis[n_train:n_train + n_test], predicted_output_test, label='Predicted')
plt.legend()
plt.xlabel('Time (s)')
plt.ylabel('Input/Output')
plt.title('Testing')

plt.show()
