import numpy as np
import matplotlib.pyplot as plt
from brian2 import *

# Train a small spiking network (Brian2) on a pre-processed SHD matrix and fit
# a linear readout from the hidden-layer membrane potentials.
#
# Protocol: one data row is presented per simulation time step
# (``defaultclock.dt``). Row k's feature values drive the N input channels as
# Poisson rates (Hz) during step k. Row k's label is the target value
# (interpreted as mV) for that step. Hidden->output synapses learn with
# pair-based STDP. The readout is an ordinary least-squares fit.

# --- Load the SHD dataset -----------------------------------------------------
# NOTE(review): assumes 'shd.npy' holds a dense 2-D array whose last column is
# the label and whose other columns are non-negative per-channel rates. The
# official SHD release ships HDF5 spike trains, so this file must have been
# pre-processed elsewhere -- confirm its layout.
N_TRAIN = 10000
data = np.load('shd.npy')
if data.ndim != 2 or data.shape[0] <= N_TRAIN:
    raise ValueError(
        f'expected a 2-D array with more than {N_TRAIN} rows in shd.npy, '
        f'got shape {data.shape}'
    )
# astype() copies, so no later arithmetic can write through a view into `data`.
x_train = data[:N_TRAIN, :-1].astype(float)
y_train = data[:N_TRAIN, -1].astype(float)
x_test = data[N_TRAIN:, :-1].astype(float)
y_test = data[N_TRAIN:, -1].astype(float)
n_train, n_test = len(x_train), len(x_test)

# --- Model parameters ---------------------------------------------------------
N = x_train.shape[1]  # number of input channels
M = 100               # number of hidden neurons
# Heterogeneous per-neuron membrane time constants in [10, 100) ms.
tau_m = 10*ms + 90*ms*np.random.rand(M)
# Brian2 equations cannot index arrays, so the output neuron gets a scalar.
tau_m_out = tau_m[0]
tau_syn_e = 1*ms
tau_syn_i = 2*ms
# Per-hidden-neuron STDP parameters, applied to that neuron's outgoing synapse.
tau_stdp = 20*ms + 40*ms*np.random.rand(M)
A_plus = 0.01 + 0.05*np.random.rand(M)
A_minus = 0.01 + 0.05*np.random.rand(M)
w_min = 0*mV
w_max = 2*mV
# NOTE(review): the original referenced an undefined resting potential and
# initial spread, and had no threshold at all. These are placeholders, not
# values tuned for SHD input rates.
v_rest = 0*mV
v_thresh = 1*mV
v_init_sd = 0.2*mV

# --- Network ------------------------------------------------------------------
# Training rows followed by test rows. run(t_train) then run(t_test) walks
# through them in order, one row per time step (TimedArray row = int(t/dt)).
input_rates = TimedArray(np.vstack([x_train, x_test]) * Hz, dt=defaultclock.dt)
input_layer = PoissonGroup(N, rates='input_rates(t, i)')

# No inhibitory synapses are connected yet, so gi simply stays at 0.
hidden_layer = NeuronGroup(
    M,
    '''
    dv/dt = (v_rest - v + ge - gi) / tau : volt
    dge/dt = -ge / tau_syn_e : volt
    dgi/dt = -gi / tau_syn_i : volt
    tau : second (constant)
    ''',
    threshold='v > v_thresh',
    reset='v = v_rest',
    method='exact',
)
hidden_layer.tau = tau_m

output_layer = NeuronGroup(
    1,
    '''
    dv/dt = (v_rest - v + ge) / tau_m_out : volt
    dge/dt = -ge / tau_syn_e : volt
    ''',
    threshold='v > v_thresh',
    reset='v = v_rest',
    method='exact',
)

# Dense random projection input -> hidden.
# NOTE(review): the original used connect(j='i') (one-to-one), which raises an
# error whenever N > M. All-to-all is assumed to be the intent -- confirm.
input_hidden = Synapses(input_layer, hidden_layer, model='w : volt',
                        on_pre='ge_post += w')
input_hidden.connect()
input_hidden.w = 'rand() * (w_max - w_min) + w_min'

# Hidden -> output with pair-based STDP, replacing Brian1's STDP class (absent
# in Brian2):
#   - a presynaptic spike depresses w by the postsynaptic trace;
#   - a postsynaptic spike potentiates w by the presynaptic trace;
#   - w is clipped to [w_min, w_max].
hidden_output = Synapses(
    hidden_layer, output_layer,
    model='''
    w : volt
    a_plus : 1 (constant)
    a_minus : 1 (constant)
    tau_pl : second (constant)
    dapre/dt = -apre / tau_pl : 1 (event-driven)
    dapost/dt = -apost / tau_pl : 1 (event-driven)
    ''',
    on_pre='''
    ge_post += w
    apre += a_plus
    w = clip(w - apost * w_max, w_min, w_max)
    ''',
    on_post='''
    apost += a_minus
    w = clip(w + apre * w_max, w_min, w_max)
    ''',
)
hidden_output.connect()
# Look up each synapse's parameters by its presynaptic (hidden) neuron index.
pre_idx = hidden_output.i[:]
hidden_output.a_plus = A_plus[pre_idx]
hidden_output.a_minus = A_minus[pre_idx]
hidden_output.tau_pl = tau_stdp[pre_idx]
# Start from small random weights. Potentiation needs output spikes, so with
# all-zero weights the output could never fire and STDP could never move w.
# The original tried to bootstrap the weights with a PoissonGroup -> Synapses
# connection, which Brian2 does not support.
hidden_output.w = 'w_min + 0.1 * (w_max - w_min) * rand()'

# --- Monitors -----------------------------------------------------------------
hidden_spikes = SpikeMonitor(hidden_layer)
hidden_trace = StateMonitor(hidden_layer, 'v', record=True)
output_spikes = SpikeMonitor(output_layer)

# --- Initial conditions -------------------------------------------------------
hidden_layer.v = 'v_rest + randn() * v_init_sd'
output_layer.v = 'v_rest + randn() * v_init_sd'

# --- Run the simulation (one data row per time step) ---------------------------
t_train = n_train * defaultclock.dt
t_test = n_test * defaultclock.dt
run(t_train)
run(t_test)

# --- Linear readout from hidden membrane potentials ---------------------------
# Map every recorded sample to the data row that was active at that step,
# rather than assuming exactly n_train + n_test samples were recorded.
step = np.rint(np.asarray(hidden_trace.t / defaultclock.dt)).astype(int)
targets_all = np.concatenate([y_train, y_test])  # labels, read as mV
valid = step < len(targets_all)
step = step[valid]
features = np.column_stack([
    np.asarray(hidden_trace.v / mV).T[valid],  # (samples, M) potentials in mV
    np.ones(step.size),                         # bias term
])
targets = targets_all[step]
is_train = step < n_train

# lstsq avoids explicitly inverting the normal equations and tolerates
# rank-deficient features (e.g. silent neurons).
w_out, *_ = np.linalg.lstsq(features[is_train], targets[is_train], rcond=None)
y_pred = features @ w_out

# --- Training and test errors (RMSE, mV) --------------------------------------
train_error = np.sqrt(np.mean((targets[is_train] - y_pred[is_train])**2))
test_error = np.sqrt(np.mean((targets[~is_train] - y_pred[~is_train])**2))

# --- Plot the output ----------------------------------------------------------
t_ms = np.asarray(hidden_trace.t / ms)[valid]
fig, ax = plt.subplots(2, 1, figsize=(12, 8))
for axis, mask, title in ((ax[0], is_train, 'Training Set'),
                          (ax[1], ~is_train, 'Test Set')):
    axis.plot(t_ms[mask], y_pred[mask], label='Predicted')
    axis.plot(t_ms[mask], targets[mask], label='Target')
    axis.set_xlabel('Time (ms)')
    axis.set_ylabel('Voltage (mV)')
    axis.set_title(title)
    axis.legend()
plt.show()

# --- Print the errors ---------------------------------------------------------
print(f'Training error: {train_error:.4f} mV')
print(f'Test error: {test_error:.4f} mV')
print(f'Hidden spikes: {hidden_spikes.num_spikes}, '
      f'output spikes: {output_spikes.num_spikes}')