import matplotlib.pyplot as plt
import numpy as np
from scipy.integrate import solve_ivp

# Extend the simulation time for training and testing: the simulation loop runs
# this many steps, applying STDP for steps t < train_time and testing afterwards.
extended_simulation_time = 1200  # number of simulation steps

# Integrate the Lorenz system over [0, extended_simulation_time], sampled every `dt`.
# NOTE(review): `dt` must already be defined at this point (presumably the Lorenz
# sampling step from an earlier cell) and small enough that t_eval has at least
# `extended_simulation_time` samples, because the simulation loop indexes these
# series by step number — confirm.
extended_solution = solve_ivp(lorenz, [0, extended_simulation_time], initial_conditions, args=(sigma, rho, beta),
                              dense_output=True, t_eval=np.arange(0, extended_simulation_time, dt))

# Min-max scale each Lorenz component (rows 0, 1, 2 of the solution) to
# [0, amplitude] and use the results as the input time series.
lorenz_components = extended_solution.y[:3]
component_min = lorenz_components.min(axis=1, keepdims=True)
component_range = lorenz_components.max(axis=1, keepdims=True) - component_min
extended_lorenz_x, extended_lorenz_y, extended_lorenz_z = (
    (lorenz_components - component_min) / component_range * amplitude
)

# Parameters
num_neurons = 292
train_time = 900  # STDP is applied only for steps t < train_time
# Euler step of the membrane update. Deliberately not named `dt`: `dt` is the
# sampling step used to build the Lorenz `t_eval` grid, and reassigning it here
# made a re-run sample the input too coarsely (too few samples -> IndexError).
neuron_dt = 1.5
threshold = -55.5
rest_potential = -65.0
refractory_period = 5.0  # compared against t - last_spike_time, i.e. in loop steps
# Initialize synaptic weights; synaptic_weights[j, i] is the synapse j -> i.
synaptic_weights = np.random.rand(num_neurons, num_neurons) * 0.1
# STDP parameters
# NOTE(review): spike-time differences below are measured in loop steps, so the
# time constants are effectively in steps, not ms — confirm intended units.
A_plus = 0.005  # Learning rate for potentiation
A_minus = 0.05  # Learning rate for depression
tau_plus = 20.0  # Time constant for potentiation
tau_minus = 10.0  # Time constant for depression

# Reset the variables
voltages = np.ones(num_neurons) * rest_potential
last_spike_time = -np.ones(num_neurons) * refractory_period
spike_trains = []
# last_pre_spike[j, i]: last spike time of presynaptic neuron j on synapse j -> i
# (row j is overwritten whenever j spikes). -inf marks "never spiked", which makes
# exp(-elapsed / tau) exactly 0, so silent neurons contribute no plasticity.
last_pre_spike = np.full((num_neurons, num_neurons), -np.inf)
neuron_ids = np.arange(num_neurons)

# The loop reads one input sample per step, so fail clearly instead of with an
# IndexError halfway through the run.
if len(extended_lorenz_x) < extended_simulation_time:
    raise ValueError(
        f"Lorenz input has {len(extended_lorenz_x)} samples but the simulation "
        f"needs {extended_simulation_time}; check the sampling step `dt`."
    )

# Extended simulation loop with STDP learning and 3D Lorenz input
for t in range(extended_simulation_time):
    # Input current: each neuron group gets one normalized Lorenz component plus
    # Gaussian noise. NOTE(review): assumes neurons_per_component (defined
    # elsewhere) satisfies 2 * neurons_per_component <= num_neurons.
    input_current_x = extended_lorenz_x[t] + np.random.randn(neurons_per_component) * 2.0
    input_current_y = extended_lorenz_y[t] + np.random.randn(neurons_per_component) * 2.0
    input_current_z = extended_lorenz_z[t] + np.random.randn(num_neurons - 2 * neurons_per_component) * 2.0

    input_current = np.concatenate([input_current_x, input_current_y, input_current_z])

    # Update voltages one neuron at a time; STDP runs only while t < train_time.
    for i in range(num_neurons):
        if t - last_spike_time[i] < refractory_period:
            continue
        # Euler step of a leak towards rest_potential driven by the input current.
        # NOTE(review): synaptic_weights never enter this update, so the learned
        # weights do not influence the dynamics (no recurrent input) — confirm.
        voltages[i] += neuron_dt * (input_current[i] - (voltages[i] - rest_potential))

        # Check for spikes
        if voltages[i] >= threshold:
            spike_trains.append((t, i))
            voltages[i] = rest_potential
            last_spike_time[i] = t

            # Apply pair-based STDP only during the training period.
            if t < train_time:
                # Row j of last_pre_spike holds neuron j's last spike, so column i
                # gives "time since each neuron last fired" (inf if never).
                elapsed = t - last_pre_spike[:, i]
                # Exclude self-connections and neurons spiking in this same step.
                earlier = (neuron_ids != i) & (elapsed > 0)
                # Pre j fired before post i -> potentiate j -> i.
                synaptic_weights[earlier, i] += A_plus * np.exp(-elapsed[earlier] / tau_plus)
                # Post k fired before pre i -> depress i -> k.
                synaptic_weights[i, earlier] -= A_minus * np.exp(-elapsed[earlier] / tau_minus)

            # i is the presynaptic neuron on every outgoing synapse i -> k.
            last_pre_spike[i, :] = t

# Raster plot of the post-training (test) period: steps from train_time to the
# end of the simulation, i.e. after STDP has been switched off.
test_start = train_time
test_spikes = [(step, neuron) for step, neuron in spike_trains if step >= test_start]

plt.figure(figsize=(10, 6))
if test_spikes:
    spike_steps, spike_neurons = np.asarray(test_spikes).T
    # One short black vertical tick per spike, centered on the neuron's row.
    plt.vlines(spike_steps, spike_neurons - 0.5, spike_neurons + 0.5, colors='k')
# Spike times are recorded as loop step indices, not milliseconds.
plt.xlabel('Time step')
plt.ylabel('Neuron Index')
plt.yticks(range(0, num_neurons, 25))
plt.tight_layout()

# Save the figure object in a variable for later display
plt_buffer_stdp = plt.gcf()
plt.close()  # Close the plot to avoid displaying it twice

plt_buffer_stdp
