import numpy as np
import matplotlib.pyplot as plt
from brian2 import *

# Seed NumPy's global generator so every random draw made later in this script
# (membrane/STDP parameters, initial weights) is reproducible.
np.random.seed(42)

# Lorenz-63 system parameters and forward-Euler integration settings.
sigma = 10
rho = 28
beta = 8/3
delta_t = 0.01
init = [-8, 8, 27]
t = np.arange(0, 100, delta_t)
n = len(t)


def _lorenz_euler_step(state):
    """Advance a Lorenz-63 state ``(x, y, z)`` by one forward-Euler step of ``delta_t``."""
    px, py, pz = state
    return [px + sigma*(py - px)*delta_t,
            py + (px*(rho - pz) - py)*delta_t,
            pz + (px*py - beta*pz)*delta_t]


# Trajectory: row k is the (x, y, z) state at time t[k], starting from ``init``.
data = np.empty((n, 3))
data[0] = init
for step in range(1, n):
    data[step] = _lorenz_euler_step(data[step - 1])



# ---------------------------------------------------------------------------
# Reservoir (liquid state machine) parameters
# ---------------------------------------------------------------------------
N = 100  # Number of neurons in the recurrent layer
sim_dt = 0.1*ms  # Integration time step shared by all Brian2 objects
tau_mean = 10*ms  # Mean membrane time constant
tau_std = 2*ms  # Standard deviation of membrane time constant
tau_min = 1*ms  # Lower bound for every sampled time constant (keeps them positive)
v_rest = 0*mV  # Resting potential
v_thresh = 1*mV  # Threshold potential
v_reset = 0*mV  # Reset potential
refractory_period = 2*ms  # Refractory period
# NOTE(review): E_exc/E_inh, input_gain and w_max were undefined in the original
# script. These values are only a starting point -- tune them until the reservoir
# fires irregularly (neither silent nor saturated; see the printed mean rate).
E_exc = 5*mV  # Excitatory reversal potential
E_inh = -5*mV  # Inhibitory reversal potential
input_gain = 2*mV  # Scale of the external input drive
inhibitory_fraction = 0.2  # Fraction of neurons whose outgoing synapses are inhibitory
w_init = 0.1  # Mean initial recurrent weight
w_max = 1.0  # Upper bound for recurrent weights under STDP
tau_syn = 5*ms  # Synaptic time constant
stdp_tau_pre_mean = 20*ms  # Mean time constant of the presynaptic STDP trace
stdp_tau_pre_std = 5*ms  # Standard deviation of the presynaptic STDP time constant
stdp_tau_post_mean = 20*ms  # Mean time constant of the postsynaptic STDP trace
stdp_tau_post_std = 5*ms  # Standard deviation of the postsynaptic STDP time constant
stdp_A_mean = 0.1  # Mean STDP learning rate (potentiation amplitude)
stdp_A_std = 0.01  # Standard deviation of the STDP learning rate
stdp_depression_ratio = 1.05  # Depression amplitude relative to potentiation
readout_tau = 10*ms  # Decay time of the per-neuron spike trace read by the readout
ridge_alpha = 1e-3  # L2 regularisation strength of the linear readout
washout = 100  # Initial samples discarded while the reservoir forgets its initial state
train_fraction = 0.8  # Fraction of post-washout samples used to fit the readout


def _draw_clipped(mean, std, size, lower, unit=1):
    """Draw ``size`` normal samples and clip them from below at ``lower``.

    ``mean``, ``std`` and ``lower`` are converted to plain numbers in ``unit``
    (a Brian2 unit, or 1 for dimensionless values) before sampling; the
    returned array carries ``unit`` again. This avoids sampling a Quantity in
    seconds and then mislabelling the result (e.g. multiplying by ``ms``).
    """
    samples = np.random.normal(float(mean / unit), float(std / unit), size)
    return np.clip(samples, float(lower / unit), None) * unit


# ---------------------------------------------------------------------------
# Model equations
# ---------------------------------------------------------------------------
# Conductance-based LIF neuron driven by the external signal. ``r`` is an
# exponentially decaying spike trace (bumped on every spike) that serves as the
# reservoir state for the linear readout. ``spike_count`` is a diagnostic.
eqs = '''
dv/dt = (v_rest - v + I_in + g_exc * (E_exc - v) + g_inh * (E_inh - v)) / tau : volt (unless refractory)
I_in = input_gain * w_in * input_signal(t) : volt
dg_exc/dt = -g_exc / tau_syn : 1
dg_inh/dt = -g_inh / tau_syn : 1
dr/dt = -r / readout_tau : 1
spike_count : 1
w_in : 1 (constant)
inhibitory : 1 (constant)
tau : second (constant)
'''
reset_eqs = '''
v = v_reset
r += 1
spike_count += 1
'''

# Pair-based STDP with per-synapse time constants and amplitude. A
# pre-before-post pair potentiates and post-before-pre depresses. Weights are
# clipped to [0, w_max]; the sign of the effect comes from the presynaptic
# neuron's ``inhibitory`` flag (0 or 1).
stdp_model = '''
w_syn : 1
tau_pre : second (constant)
tau_post : second (constant)
A_stdp : 1 (constant)
dapre/dt = -apre / tau_pre : 1 (event-driven)
dapost/dt = -apost / tau_post : 1 (event-driven)
'''
stdp_pre = '''
g_exc_post += w_syn * (1 - inhibitory_pre)
g_inh_post += w_syn * inhibitory_pre
apre += A_stdp
w_syn = clip(w_syn + apost, 0, w_max)
'''
stdp_post = '''
apost -= stdp_depression_ratio * A_stdp
w_syn = clip(w_syn + apre, 0, w_max)
'''

# ---------------------------------------------------------------------------
# Input signal
# ---------------------------------------------------------------------------
sample_dt = delta_t * second  # Each Lorenz sample is presented for this long (10 ms)
n_samples = len(data)  # Lower this to shorten the (slow) simulation
x_raw = data[:n_samples, 0]
x_mean, x_std = x_raw.mean(), x_raw.std()
x_norm = (x_raw - x_mean) / x_std  # Zero-mean, unit-variance drive and target
input_signal = TimedArray(x_norm, dt=sample_dt)
duration = n_samples * sample_dt

# ---------------------------------------------------------------------------
# Recurrent layer
# ---------------------------------------------------------------------------
defaultclock.dt = sim_dt
neurons = NeuronGroup(N, model=eqs, threshold='v > v_thresh', reset=reset_eqs,
                      refractory=refractory_period, method='euler')
neurons.tau = _draw_clipped(tau_mean, tau_std, N, tau_min, unit=ms)
neurons.v = v_reset
neurons.g_inh = 0
neurons.g_exc = 0
neurons.r = 0
neurons.spike_count = 0
neurons.w_in = np.random.uniform(-1, 1, N)  # Heterogeneous input sensitivity
# Input weights are random, so marking the first neurons as inhibitory is unbiased.
neurons.inhibitory = (np.arange(N) < int(round(inhibitory_fraction * N))).astype(float)

# All-to-all recurrent synapses without self-connections (autapses would let a
# neuron potentiate its own loop under STDP).
synapses = Synapses(neurons, neurons, model=stdp_model, on_pre=stdp_pre, on_post=stdp_post)
synapses.connect(condition='i != j')
n_syn = len(synapses)
synapses.w_syn = np.minimum(_draw_clipped(w_init, w_init/10, n_syn, 0), w_max)
synapses.tau_pre = _draw_clipped(stdp_tau_pre_mean, stdp_tau_pre_std, n_syn, tau_min, unit=ms)
synapses.tau_post = _draw_clipped(stdp_tau_post_mean, stdp_tau_post_std, n_syn, tau_min, unit=ms)
synapses.A_stdp = _draw_clipped(stdp_A_mean, stdp_A_std, n_syn, 0)

# Sample the spike traces once per input sample.
state_monitor = StateMonitor(neurons, 'r', record=True, dt=sample_dt)

# ---------------------------------------------------------------------------
# Run the simulation
# ---------------------------------------------------------------------------
net = Network(neurons, synapses, state_monitor)
net.run(duration, report='text')
mean_rate = float(np.mean(neurons.spike_count[:])) / float(duration / second)
print(f'Mean reservoir firing rate: {mean_rate:.1f} Hz')

# ---------------------------------------------------------------------------
# Train the linear readout (closed-form ridge regression)
# ---------------------------------------------------------------------------
# The original SGD loop referenced an undefined optimizer and never re-ran the
# network, so its "training" could not change anything. A linear readout on
# fixed reservoir states has a closed-form ridge solution instead.
# StateMonitor records at the start of each of its steps, so row k holds the
# traces at t = k*sample_dt, i.e. after inputs x[0..k-1]. Pairing it with
# target x[k] gives one-step-ahead prediction.
states = np.asarray(state_monitor.r).T  # (n_recorded, N)
n_rec = min(len(states), n_samples)
states = states[:n_rec]
targets = x_norm[:n_rec]
design = np.column_stack([states, np.ones(n_rec)])  # Trailing column = bias term

usable = np.arange(washout, n_rec)
n_train = int(train_fraction * len(usable))
if n_train == 0 or n_train == len(usable):
    raise ValueError('Not enough samples after washout for a train/test split; '
                     'increase n_samples or reduce washout.')
train_idx, test_idx = usable[:n_train], usable[n_train:]

X_train, y_train = design[train_idx], targets[train_idx]
gram = X_train.T @ X_train + ridge_alpha * np.eye(design.shape[1])
readout_weights = np.linalg.solve(gram, X_train.T @ y_train)
prediction = design @ readout_weights
test_mse = np.mean((prediction[test_idx] - targets[test_idx])**2)
print(f'Test MSE (normalised x units): {test_mse:.4f}')

# ---------------------------------------------------------------------------
# Plot the results (network time; one Lorenz step = sample_dt)
# ---------------------------------------------------------------------------
sample_times = np.arange(n_rec) * float(sample_dt / second)
predicted_x = prediction * x_std + x_mean  # Undo the input normalisation
plt.figure(figsize=(10, 6))
plt.plot(sample_times[washout:], predicted_x[washout:], label='Predicted')
plt.plot(sample_times, x_raw[:n_rec], label='Target')
plt.axvline(sample_times[test_idx[0]], color='k', linestyle='--', label='Train/test split')
plt.legend()
plt.xlabel('Time (s)')
plt.ylabel('X')
plt.show()