import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from skopt import gp_minimize
from skopt.space import Real, Integer
from skopt.utils import use_named_args
# Kept last so that names exported by the star import keep shadowing the ones above.
from brian2 import *

# Load the SHD dataset from a comma-separated text file.
# Every column except the last is treated as a feature; the last column is the target.
data = np.loadtxt('SHD.txt', delimiter=',')
x, y = data[:, :-1], data[:, -1]

# Hold out 20% of the rows for testing; the fixed seed keeps the split reproducible.
x_train, x_test, y_train, y_test = train_test_split(
    x,
    y,
    test_size=0.2,
    random_state=42,
)

# Define the objective function
@use_named_args([
    Real(0.1, 10.0, name='tau_mem'),
    Real(0.1, 10.0, name='tau_syn_e'),
    Real(0.1, 10.0, name='tau_syn_i'),
    Integer(10, 100, name='n_hidden'),
    Real(0.0, 1.0, name='p_inhib'),
    Real(0.0, 1.0, name='w_inhib'),
    Real(0.0, 1.0, name='w_exc'),
    Real(0.0, 1.0, name='gamma_p'),
    Real(0.0, 1.0, name='gamma_d'),
    Real(0.0, 1.0, name='w_min'),
    Real(0.0, 1.0, name='w_max'),
    Real(1e-5, 1e-3, name='alpha'),
    Real(1e-5, 1e-3, name='beta'),
    Real(1e-5, 1e-3, name='tau_c'),
    Real(1e-5, 1e-3, name='tau_n')
])
def objective(tau_mem, tau_syn_e, tau_syn_i, n_hidden, p_inhib, w_inhib, w_exc, gamma_p, gamma_d, w_min, w_max, alpha, beta, tau_c, tau_n):
    """Build an input/hidden/output Brian2 network, run it and score it.

    The hyper-parameters are not passed to Brian2 explicitly: the equation and
    synapse strings below refer to them by name (e.g. ``tau_mem``, ``gamma_p``,
    ``w_min``), so they must stay available as local names of this function.
    ``w_inhib`` and ``tau_c`` are accepted but not referenced by any string or
    statement in this function.

    Args:
        tau_mem, tau_syn_e, tau_syn_i: Time constants used as divisors in the
            neuron equations.
        n_hidden: Number of neurons in the hidden group.
        p_inhib, w_inhib, w_exc, gamma_p, gamma_d, w_min, w_max, alpha, beta,
            tau_c, tau_n: Scalars referenced by name in the model strings.

    Returns:
        The mean squared difference between the readout prediction for
        ``x_test`` and ``y_test``.
    """
    # Set up the model
    defaultclock.dt = 0.1*ms
    N_input = x.shape[1]
    # `y` is sliced as `data[:, -1]` and is therefore 1-D; indexing it with
    # `shape[1]` / `[:, i]` raised IndexError. View 1-D targets as one column
    # so both 1-D and 2-D target layouts are handled.
    train_targets = y_train if y_train.ndim > 1 else y_train[:, np.newaxis]
    N_output = train_targets.shape[1]
    N_hidden = n_hidden

    # Define the network equations and synaptic connections
    # NOTE(review): the time constants arrive as plain floats while the
    # variables are declared in volt; the equations probably need explicit time
    # units to pass Brian2's dimension check - confirm the intended units.
    eqs = '''
    dv/dt = (ge-gi-(v+p_inhib*gi)-v)/tau_mem : volt (unless refractory)
    dge/dt = -ge/tau_syn_e : volt
    dgi/dt = -gi/tau_syn_i : volt
    '''
    inhibitory_connections = 'w : 1'
    excitatory_connections = 'w : 1 (shared)'

    # Define the inhibitory synapses
    inhibitory_synapses = '''
    p_post += gamma_p * w
    dp/dt = -beta * p * post : 1 (clock-driven)
    '''
    inhibitory_on_pre = '''
    gi += w
    p = clip(p + alpha, 0, 1)
    '''
    inhibitory_on_post = '''
    p = clip(p + alpha, 0, 1)
    '''

    # Define the excitatory synapses
    excitatory_synapses = '''
    d_post += gamma_d * w
    dd/dt = -1/tau_n * d : 1 (clock-driven)
    '''
    # NOTE(review): the two excitatory event strings below are defined but not
    # passed to any Synapses object in this function.
    excitatory_on_pre = '''
    ge += w
    '''
    excitatory_on_post = '''
    d = clip(d + 1, 0, w_max/w_min)
    w = clip(w + d*w_exc, w_min, w_max)
    '''

    # Create the neurons
    input_neurons = NeuronGroup(N_input, '', threshold='True', refractory=1*ms, method='euler')
    hidden_neurons = NeuronGroup(N_hidden, eqs, threshold='True', refractory=1*ms, method='euler')
    output_neurons = NeuronGroup(N_output, eqs, threshold='True', refractory=1*ms, method='euler')

    # Connect the input layer to the hidden layer
    input_synapses = Synapses(input_neurons, hidden_neurons, model=excitatory_connections)
    input_synapses.connect()
    input_synapses.w = 'rand()'

    # Connect the hidden layer to itself
    # NOTE(review): the model is the plain concatenation of three strings, one
    # of which starts without a newline and two of which contain `+=`
    # statements; confirm Brian2 accepts this as a synapse model.
    recurrent_synapses = Synapses(hidden_neurons, hidden_neurons, model=inhibitory_connections+excitatory_synapses+inhibitory_synapses,
                                  on_pre=inhibitory_on_pre, on_post=inhibitory_on_post)
    recurrent_synapses.connect(condition='i!=j')
    recurrent_synapses.w = 'w_min + rand()*(w_max-w_min)'

    # Connect the hidden layer to the output layer
    output_synapses = Synapses(hidden_neurons, output_neurons, model=excitatory_connections)
    output_synapses.connect()
    output_synapses.w = 'rand()'

    # Set up the input and output spikes
    # SpikeGeneratorGroup takes flat, parallel `indices` and `times` arrays.
    # Every non-zero entry (row r, column c) becomes a spike of neuron c at
    # r ms, which is the mapping the former per-column lists expressed.
    input_rows, input_cols = np.nonzero(x_train)
    input_spikes = SpikeGeneratorGroup(N_input, input_cols, input_rows*ms)
    target_rows, target_cols = np.nonzero(train_targets)
    output_spikes = SpikeGeneratorGroup(N_output, target_cols, target_rows*ms)
    # NOTE(review): neither generator is connected to a neuron group by any
    # Synapses object in this function.

    # Define the readout layer and the output monitors
    # NOTE(review): SigmoidalLayer is not defined or explicitly imported in
    # this file; confirm where it is supposed to come from.
    readout_layer = SigmoidalLayer(hidden_neurons, output_neurons)
    output_monitors = SpikeMonitor(output_neurons)

    # Run the network
    run(5000*ms)

    # Compute the validation error
    y_pred = readout_layer.predict(hidden_neurons, output_neurons, x_test)
    error = np.mean((y_pred - y_test)**2)

    return error

# Define the search space: one skopt dimension per hyper-parameter, in the
# positional order that objective_function reads them (indices 0..14).
search_space = [
    # Values used as gamma shape parameters for the per-neuron time constants.
    *(Real(0.1, 10.0, name=label) for label in ('tau_mem', 'tau_syn_e', 'tau_syn_i')),
    # Hidden layer size.
    Integer(10, 100, name='n_hidden'),
    # Values used as gamma shape parameters for inhibition / weight draws.
    *(Real(0.0, 1.0, name=label) for label in ('p_inhib', 'w_inhib', 'w_exc')),
    # Lower/upper bounds of the uniform draws for gamma_p, alpha, beta, gamma_d.
    Real(0, 1.0, name='gamma_p_min'),
    *(
        Real(0.0, 1.0, name=label)
        for label in (
            'gamma_p_max',
            'alpha_min',
            'alpha_max',
            'beta_min',
            'beta_max',
            'gamma_d_min',
            'gamma_d_max',
        )
    ),
]

# Define the objective function
def objective_function(params):
    """Sample per-neuron parameters from `params` and score the resulting network.

    Args:
        params: Sequence whose first 15 entries are, in order: tau_mem,
            tau_syn_e, tau_syn_i, n_hidden, p_inhib, w_inhib, w_exc,
            gamma_p_min, gamma_p_max, alpha_min, alpha_max, beta_min,
            beta_max, gamma_d_min, gamma_d_max.

    Returns:
        A dict with the keys 'loss' (the value returned by `create_network`),
        'params' (the input, unchanged) and 'status'.
    """
    (tau_mem, tau_syn_e, tau_syn_i, n_hidden, p_inhib, w_inhib, w_exc,
     gamma_p_min, gamma_p_max, alpha_min, alpha_max, beta_min, beta_max,
     gamma_d_min, gamma_d_max) = params[:15]

    # The hidden-layer size is one of the searched values (params[3]); the
    # samples were previously sized with an undefined global `N_hidden`.
    n_neurons = int(n_hidden)

    def _uniform_between(bound_a, bound_b):
        """Draw `n_neurons` uniform samples between two bounds given in any order."""
        # Each *_min / *_max pair is searched independently, so the "min" may
        # exceed the "max"; np.random.uniform documents low > high as undefined.
        low, high = sorted((bound_a, bound_b))
        return np.random.uniform(low, high, size=n_neurons)

    # Sample the heterogeneous parameters (each scalar is used as a gamma shape)
    tau_mem_values = np.random.gamma(tau_mem, size=n_neurons)
    tau_syn_e_values = np.random.gamma(tau_syn_e, size=n_neurons)
    tau_syn_i_values = np.random.gamma(tau_syn_i, size=n_neurons)
    p_inhib_values = np.random.gamma(p_inhib, size=n_neurons)
    w_inhib_values = np.random.gamma(w_inhib, size=n_neurons)
    w_exc_values = np.random.gamma(w_exc, size=n_neurons)
    gamma_p_values = _uniform_between(gamma_p_min, gamma_p_max)
    alpha_values = _uniform_between(alpha_min, alpha_max)
    beta_values = _uniform_between(beta_min, beta_max)
    gamma_d_values = _uniform_between(gamma_d_min, gamma_d_max)

    # Set the model parameters
    eqs = '''
    dv/dt = (-v + ge + gi)/tau_mem_values : 1 (unless refractory)
    dge/dt = -ge/tau_syn_e_values : 1
    dgi/dt = -gi/tau_syn_i_values : 1
    '''

    inhibitory_connections = '''
    w : 1
    '''

    excitatory_connections = '''
    w : 1
    '''

    # Create the network
    # NOTE(review): create_network and STATUS_OK are not defined or imported in
    # this file; confirm where they are meant to come from.
    error = create_network(eqs, inhibitory_connections, excitatory_connections,
                           tau_mem_values, tau_syn_e_values, tau_syn_i_values, p_inhib_values, w_inhib_values, w_exc_values,
                           gamma_p_values, alpha_values, beta_values, gamma_d_values)

    return {'loss': error, 'params': params, 'status': STATUS_OK}

# Perform the bayesian optimization
# NOTE(review): Trials, fmin and tpe are not imported in this file (they look
# like hyperopt names - confirm), and `search_space` at this point is a list of
# skopt dimensions; verify that the optimiser being called accepts it.
trials = Trials()
best_params = fmin(objective_function, search_space, algo=tpe.suggest, trials=trials, max_evals=50)

# Print the best parameters
print('Best parameters:')
for key, value in best_params.items():
    print('{}: {}'.format(key, value))

# Evaluate the performance of the optimized model
# The hidden-layer size was never assigned at module scope although every
# sample below is sized with it. Take it from the optimised value; this assumes
# best_params is keyed by the search-space names, as the lookups below already do.
N_hidden = int(best_params['n_hidden'])

# Sample the heterogeneous parameters using the optimized hyperparameters
tau_mem_values = np.random.gamma(best_params['tau_mem'], size=N_hidden)
tau_syn_e_values = np.random.gamma(best_params['tau_syn_e'], size=N_hidden)
tau_syn_i_values = np.random.gamma(best_params['tau_syn_i'], size=N_hidden)
p_inhib_values = np.random.gamma(best_params['p_inhib'], size=N_hidden)
w_inhib_values = np.random.gamma(best_params['w_inhib'], size=N_hidden)
w_exc_values = np.random.gamma(best_params['w_exc'], size=N_hidden)
gamma_p_values = np.random.uniform(best_params['gamma_p_min'], best_params['gamma_p_max'], size=N_hidden)
alpha_values = np.random.uniform(best_params['alpha_min'], best_params['alpha_max'], size=N_hidden)
beta_values = np.random.uniform(best_params['beta_min'], best_params['beta_max'], size=N_hidden)
gamma_d_values = np.random.uniform(best_params['gamma_d_min'], best_params['gamma_d_max'], size=N_hidden)

# Set the model parameters
eqs = '''
dv/dt = (-v + ge + gi)/tau_mem_values : 1 (unless refractory)
dge/dt = -ge/tau_syn_e_values : 1
dgi/dt = -gi/tau_syn_i_values : 1
'''

inhibitory_connections = '''
w : 1
'''

excitatory_connections = '''
w : 1
'''

# Create the network
# NOTE(review): create_network is not defined in this file. Here its result is
# used as an object with a .run() method, whereas objective_function treats the
# result as the loss value - confirm which contract is intended.
snn = create_network(eqs, inhibitory_connections, excitatory_connections, tau_mem_values, tau_syn_e_values,
                     tau_syn_i_values, p_inhib_values, w_inhib_values, w_exc_values, gamma_p_values, alpha_values,
                     beta_values, gamma_d_values)

# NOTE(review): train_input, test_input, train_target and test_target are not
# defined in this file; presumably they correspond to x_train, x_test, y_train
# and y_test from the split at the top - confirm before substituting.

# Train the readout layer using the training data
train_output = snn.run(train_input, report='text')

# Test the model on the test data
test_output = snn.run(test_input, report='text')

# Evaluate the performance of the model
train_error = mean_squared_error(train_output, train_target)
test_error = mean_squared_error(test_output, test_target)

print('Training error:', train_error)
print('Test error:', test_error)

# Perform Bayesian optimization to find the optimal hyperparameters
def evaluate_network(tau_mem, tau_syn_e, tau_syn_i, p_inhib, w_inhib, w_exc, gamma_p_min, gamma_p_max, alpha_min,
                     alpha_max, beta_min, beta_max, gamma_d_min, gamma_d_max):
    """Sample per-neuron parameters, build a network and return its test error.

    Args:
        tau_mem, tau_syn_e, tau_syn_i, p_inhib, w_inhib, w_exc: Scalars used as
            the shape of a gamma distribution for the matching per-neuron values.
        gamma_p_min, gamma_p_max, alpha_min, alpha_max, beta_min, beta_max,
            gamma_d_min, gamma_d_max: Bounds of the uniform distributions the
            matching per-neuron values are drawn from.

    Returns:
        ``mean_squared_error(test_output, test_target)`` - lower is better.

    NOTE(review): N_hidden, create_network, train_input, test_input and
    test_target are read from module scope and are not defined in this
    function; confirm they exist before this is called.
    """
    # Sample the heterogeneous parameters
    tau_mem_values = np.random.gamma(tau_mem, size=N_hidden)
    tau_syn_e_values = np.random.gamma(tau_syn_e, size=N_hidden)
    tau_syn_i_values = np.random.gamma(tau_syn_i, size=N_hidden)
    p_inhib_values = np.random.gamma(p_inhib, size=N_hidden)
    w_inhib_values = np.random.gamma(w_inhib, size=N_hidden)
    w_exc_values = np.random.gamma(w_exc, size=N_hidden)
    gamma_p_values = np.random.uniform(gamma_p_min, gamma_p_max, size=N_hidden)
    alpha_values = np.random.uniform(alpha_min, alpha_max, size=N_hidden)
    beta_values = np.random.uniform(beta_min, beta_max, size=N_hidden)
    gamma_d_values = np.random.uniform(gamma_d_min, gamma_d_max, size=N_hidden)

    # Set the model parameters
    eqs = '''
    dv/dt = (-v + ge + gi)/tau_mem_values : 1 (unless refractory)
    dge/dt = -ge/tau_syn_e_values : 1
    dgi/dt = -gi/tau_syn_i_values : 1
    '''

    inhibitory_connections = '''
    w : 1
    '''

    excitatory_connections = '''
    w : 1
    '''

    # Create the network
    snn = create_network(eqs, inhibitory_connections, excitatory_connections, tau_mem_values, tau_syn_e_values,
                         tau_syn_i_values, p_inhib_values, w_inhib_values, w_exc_values, gamma_p_values, alpha_values,
                         beta_values, gamma_d_values)

    # Run on the training data first. The call is kept for whatever effect it
    # has on `snn`; its output was only used for a training error that was
    # computed and then discarded, so neither is kept.
    snn.run(train_input, report=False)

    # Test the model on the test data
    test_output = snn.run(test_input, report=False)

    # Evaluate the performance of the model
    return mean_squared_error(test_output, test_target)

# Define the search space for the hyperparameters
# (this rebinds `search_space` to a dict of (lower, upper) bounds per parameter)
search_space = {'tau_mem': (1, 100), 'tau_syn_e': (0.1, 10), 'tau_syn_i': (0.1, 10), 'p_inhib': (0, 1),
                'w_inhib': (0, 5), 'w_exc': (0, 5), 'gamma_p_min': (0, 1), 'gamma_p_max': (1, 5),
                'alpha_min': (0, 1), 'alpha_max': (1, 5), 'beta_min': (0, 1), 'beta_max': (1, 5),
                'gamma_d_min': (0, 1), 'gamma_d_max': (1, 5)}


def _negated_test_error(**hyperparams):
    """Return minus the test error so that maximising it minimises the error."""
    return -evaluate_network(**hyperparams)


# Initialize the Bayesian optimization algorithm
# The optimiser's maximize() looks for the LARGEST target. Passing
# evaluate_network directly made it search for the highest test error, so the
# negated error is optimised instead.
# NOTE(review): BayesianOptimization is not imported in this file (it looks
# like the bayes_opt class - confirm and add the import).
optimizer = BayesianOptimization(_negated_test_error, search_space)

# Perform the optimization
optimizer.maximize(n_iter=20, init_points=5)

# Print the best hyperparameters and the corresponding test error
best_result = optimizer.max
best_hyperparams = best_result['params']
print('Best hyperparameters:', best_hyperparams)
# Undo the negation so the reported number is the actual test error.
print('Test error:', -best_result['target'])

# Extract the best hyperparameters
best_tau_mem = best_hyperparams['tau_mem']
best_tau_syn_e = best_hyperparams['tau_syn_e']
best_tau_syn_i = best_hyperparams['tau_syn_i']
best_p_inhib = best_hyperparams['p_inhib']
best_w_inhib = best_hyperparams['w_inhib']
best_w_exc = best_hyperparams['w_exc']
best_gamma_p_min = best_hyperparams['gamma_p_min']
best_gamma_p_max = best_hyperparams['gamma_p_max']
best_alpha_min = best_hyperparams['alpha_min']
best_alpha_max = best_hyperparams['alpha_max']
best_beta_min = best_hyperparams['beta_min']
best_beta_max = best_hyperparams['beta_max']
best_gamma_d_min = best_hyperparams['gamma_d_min']
best_gamma_d_max = best_hyperparams['gamma_d_max']

# Create the optimized network
# NOTE(review): N_hidden is not assigned anywhere in this section and this
# search space has no hidden-size entry; it must already exist at module scope.
tau_mem_values = np.random.gamma(best_tau_mem, size=N_hidden)
tau_syn_e_values = np.random.gamma(best_tau_syn_e, size=N_hidden)
tau_syn_i_values = np.random.gamma(best_tau_syn_i, size=N_hidden)
p_inhib_values = np.random.gamma(best_p_inhib, size=N_hidden)
w_inhib_values = np.random.gamma(best_w_inhib, size=N_hidden)
w_exc_values = np.random.gamma(best_w_exc, size=N_hidden)
gamma_p_values = np.random.uniform(best_gamma_p_min, best_gamma_p_max, size=N_hidden)
alpha_values = np.random.uniform(best_alpha_min, best_alpha_max, size=N_hidden)
beta_values = np.random.uniform(best_beta_min, best_beta_max, size=N_hidden)
gamma_d_values = np.random.uniform(best_gamma_d_min, best_gamma_d_max, size=N_hidden)

# `eqs`, `inhibitory_connections` and `excitatory_connections` are the
# module-level strings assigned earlier in the script.
# NOTE(review): create_network, train_input, test_input, train_target and
# test_target are not defined in this file - confirm their source.
snn_optimized = create_network(eqs, inhibitory_connections, excitatory_connections, tau_mem_values, tau_syn_e_values,
                         tau_syn_i_values, p_inhib_values, w_inhib_values, w_exc_values, gamma_p_values, alpha_values,
                         beta_values, gamma_d_values)

# Train the readout layer using the training data
train_output_optimized = snn_optimized.run(train_input, report=False)

# Test the model on the test data
test_output_optimized = snn_optimized.run(test_input, report=False)

# Evaluate the performance of the optimized model
train_error_optimized = mean_squared_error(train_output_optimized, train_target)
test_error_optimized = mean_squared_error(test_output_optimized, test_target)

print('Optimized hyperparameters:', {'tau_mem': best_tau_mem, 'tau_syn_e': best_tau_syn_e, 'tau_syn_i': best_tau_syn_i,
                                      'p_inhib': best_p_inhib, 'w_inhib': best_w_inhib, 'w_exc': best_w_exc,
                                      'gamma_p_min': best_gamma_p_min, 'gamma_p_max': best_gamma_p_max,
                                      'alpha_min': best_alpha_min, 'alpha_max': best_alpha_max,
                                      'beta_min': best_beta_min, 'beta_max': best_beta_max,
                                      'gamma_d_min': best_gamma_d_min, 'gamma_d_max': best_gamma_d_max})
print('Optimized model test error:', test_error_optimized)
