from bayes_opt import BayesianOptimization
from bptt_snn_model import*
import json

# Architecture defaults for stage 2. network_traning_function_stage_2 reads
# the layer/skip values; optim_dynamic_no sizes the stage-2 search space.
# snn_optimizer() overwrites all five with the best stage-1 parameters.
optim_num_of_layer = 10
optim_skip_layer_start = 2
optim_skip_layer_end = 8
optim_skip_layer_gap = 2
optim_dynamic_no = 4

def bo_logger(optimizer, file_name):
    """Dump every Bayesian-optimization result to *file_name* as JSON.

    The written object maps each iteration index (``json`` serializes the
    int keys as strings) to the matching entry of ``optimizer.res``.
    The file is overwritten.
    """
    history = dict(enumerate(optimizer.res))
    with open(file_name, "w") as handle:
        json.dump(history, handle)

def network_traning_function_stage_1(num_of_layer, skip_layer_start, skip_layer_end, skip_layer_gap, dynamic_no, dynamic_list=None, R_m=300, tau_m=100):
    """Stage-1 objective: train an SNN with the given architecture and return its best accuracy.

    ``BayesianOptimization`` proposes continuous values, so every structural
    argument is truncated to ``int`` before use.

    Args:
        num_of_layer, skip_layer_start, skip_layer_end, skip_layer_gap:
            Architecture parameters forwarded to ``snn_network_object``.
        dynamic_no: Number of dynamics, forwarded as ``dynamic_no``.
        dynamic_list: Optional decay values forwarded as ``decay_rate``.
            When omitted (or empty) and ``dynamic_no > 0``, ``dynamic_no``
            evenly spaced values ``-30, -30 + 25/dynamic_no, ...`` (all in
            [-30, -5)) are generated. An empty list supplied by the caller
            is filled in place.
        R_m, tau_m: Forwarded unchanged to ``snn_network_object``.

    Returns:
        ``snn_object.best_acc`` after ``training_process()`` has run.
    """
    skip_layer_start = int(skip_layer_start)
    skip_layer_end = int(skip_layer_end)
    skip_layer_gap = int(skip_layer_gap)
    num_of_layer = int(num_of_layer)
    dynamic_no = int(dynamic_no)

    # Create a fresh list on every call. A shared mutable default would keep
    # the decay values generated by the first call, so later calls with a
    # different dynamic_no would pass a decay list of the wrong length.
    if dynamic_list is None:
        dynamic_list = []

    if len(dynamic_list) == 0 and dynamic_no > 0:
        gap = 25 / dynamic_no
        for i in range(dynamic_no):
            dynamic_list.append(-30 + i * gap)

    print("Chosen param:", skip_layer_start, skip_layer_end, skip_layer_gap)
    snn_object = snn_network_object(num_of_layer=num_of_layer, depth_scale=3, skip_layer_start=skip_layer_start, skip_layer_end=skip_layer_end, skip_layer_gap=skip_layer_gap, dynamic_no=dynamic_no, decay_rate=dynamic_list, R_m=R_m, tau_m=tau_m)
    snn_object.training_process()
    print("best_acc:", snn_object.best_acc)
    return snn_object.best_acc


def network_traning_function_stage_2(**arg):
    """Stage-2 objective: train the fixed-architecture SNN with tuned membrane/decay params.

    The architecture comes from the module-level ``optim_*`` globals, which
    ``snn_optimizer`` sets from the best stage-1 result.

    Keyword Args (supplied by ``BayesianOptimization`` from the stage-2 bounds):
        R_m, tau_m: Truncated to ``int``, then multiplied by 10.
        "0", "1", ...: One decay value per dynamic. Each is truncated to
            ``int`` and ordered by its numeric key.

    Returns:
        ``snn_object.best_acc`` after ``training_process()`` has run.
    """
    skip_layer_start = int(optim_skip_layer_start)
    skip_layer_end = int(optim_skip_layer_end)
    skip_layer_gap = int(optim_skip_layer_gap)
    num_of_layer = int(optim_num_of_layer)

    print(arg)
    R_m = int(arg['R_m']) * 10
    tau_m = int(arg['tau_m']) * 10

    # Select decay entries by their numeric keys instead of assuming every
    # key other than R_m/tau_m is one. Sort numerically so "10" follows "9".
    decay_keys = sorted((key for key in arg if key.isdigit()), key=int)
    dynamic_list = [int(arg[key]) for key in decay_keys]

    dynamic_no = len(dynamic_list)
    snn_object = snn_network_object(num_of_layer=num_of_layer, depth_scale=3, skip_layer_start=skip_layer_start, skip_layer_end=skip_layer_end, skip_layer_gap=skip_layer_gap, dynamic_no=dynamic_no, decay_rate=dynamic_list, R_m=R_m, tau_m=tau_m)
    snn_object.training_process()
    print("best_acc:", snn_object.best_acc)
    return snn_object.best_acc


def _run_bayesian_optimization(objective, pbounds):
    """Maximize *objective* over *pbounds* using the settings shared by both stages.

    The settings are 20 random initial probes followed by 50 guided
    iterations, with ``random_state=1`` so runs are reproducible.

    Returns:
        The ``BayesianOptimization`` instance after ``maximize`` has run.
    """
    optimizer = BayesianOptimization(
        f=objective,
        pbounds=pbounds,
        verbose=2,  # 2 prints every step; 1 prints only when a maximum is observed; 0 is silent
        random_state=1,
    )
    optimizer.maximize(
        init_points=20,
        n_iter=50,
    )
    return optimizer


def _write_iteration_log(optimizer, file_name):
    """Write a human-readable log of every entry in ``optimizer.res`` to *file_name*.

    The file is overwritten. Each line is newline-terminated so the header
    and consecutive iterations do not run together.
    """
    with open(file_name, "w") as text_file:
        text_file.write("======New Run====\n")
        for i, res in enumerate(optimizer.res):
            text_file.write("Iteration {}: \n\t{}\n".format(i, res))


def snn_optimizer():
    """Run the two-stage Bayesian optimization of the SNN.

    Stage 1 searches the architecture (``num_of_layer``, skip-layer
    start/end/gap, ``dynamic_no``) with ``network_traning_function_stage_1``.
    The best parameters are truncated to ``int`` and stored in the
    module-level ``optim_*`` globals, which stage 2 reads.

    Stage 2 searches ``tau_m``, ``R_m`` and one decay value per dynamic
    (keys ``"0"`` .. ``str(optim_dynamic_no - 1)``, each bounded to
    (-30, -5)) with ``network_traning_function_stage_2``.

    Side effects: updates the ``optim_*`` globals and writes
    ``bptt_bo_stage_1.json``,
    ``bayesian_optimization_first_stage_result.txt``,
    ``bp_hsnn_ncaltech_stage_2.json`` and
    ``bayesian_optimization_second_stage_result.txt``.
    """
    global optim_num_of_layer
    global optim_skip_layer_end
    global optim_skip_layer_start
    global optim_skip_layer_gap
    global optim_dynamic_no

    # ---- Stage 1: architecture search ----
    stage_1_bounds = {'num_of_layer': (4, 15), 'skip_layer_start': (2, 13), 'skip_layer_end': (3, 15), 'skip_layer_gap': (2, 5), 'dynamic_no': (1, 10)}
    optimizer = _run_bayesian_optimization(network_traning_function_stage_1, stage_1_bounds)
    print(optimizer.max)

    best_params = optimizer.max['params']
    optim_num_of_layer = int(best_params['num_of_layer'])
    optim_skip_layer_end = int(best_params['skip_layer_end'])
    optim_skip_layer_start = int(best_params['skip_layer_start'])
    optim_skip_layer_gap = int(best_params['skip_layer_gap'])
    optim_dynamic_no = int(best_params['dynamic_no'])

    bo_logger(optimizer, "bptt_bo_stage_1.json")
    _write_iteration_log(optimizer, "bayesian_optimization_first_stage_result.txt")

    # ---- Stage 2: membrane constants and per-dynamic decay values ----
    # Stage 2 scales R_m and tau_m by 10 inside the objective.
    stage_2_bounds = {'tau_m': (5, 20), 'R_m': (20, 40)}
    for i in range(optim_dynamic_no):
        stage_2_bounds[str(i)] = (-30, -5)

    optimizer = _run_bayesian_optimization(network_traning_function_stage_2, stage_2_bounds)
    bo_logger(optimizer, "bp_hsnn_ncaltech_stage_2.json")
    _write_iteration_log(optimizer, "bayesian_optimization_second_stage_result.txt")


if __name__ == "__main__":
    # Guarded so that importing this module (e.g. to reuse the objective
    # functions) does not launch the full two-stage optimization.
    snn_optimizer()

