from pytictoc import TicToc
from read_files import *
from help_funcs import *
import pandas as pd

if __name__ == '__main__':
    # Driver script: for every perturbation radius in `deltas`, run the bound
    # propagation / refinement pipeline on a "base" network (net 1) and a
    # "checked" network (net 2) for each sample that BOTH networks classify
    # correctly, then write one .xlsx result sheet per radius.
    #
    # Configuration names (path_data, path_net_1, path_net_2, deltas,
    # num_not_classes, the BIH/CHB/MNIST/CIFAR10 dataset flags, list_all,
    # list_dataset) and all helper functions come from the star imports of
    # `read_files` and `help_funcs`.

    # Fail fast: the output file name is derived from the dataset flags. Without
    # this check, a missing flag only surfaced as a NameError on `filepath`
    # after a complete (and potentially very long) verification pass.
    if not (BIH or CHB or MNIST or CIFAR10):
        raise ValueError("None of the dataset flags BIH, CHB, MNIST, CIFAR10 is set; "
                         "cannot name the output .xlsx file")

    # Accumulated across all deltas (one entry per verified-attempted sample).
    min_relative_robust = list()
    max_relative_robust = list()
    all_processing_time = list()

    print(path_data)
    print(path_net_1, path_net_2)
    # For MNIST and CIFAR change load_net to load_net_2 whenever there is a distilled network
    NNmodel_1 = load_net(path_net_1)  # base network
    NNmodel_2 = load_net(path_net_2)  # checked network
    data, data_n, labels = load_data(path_data)

    W_1, layer_type_1, layer_activation_1, n_neu_1, n_neu_cum_1 = model_properties(NNmodel_1)
    W_2, layer_type_2, layer_activation_2, n_neu_2, n_neu_cum_2 = model_properties(NNmodel_2)
    # data_n_* are fed to the models' evaluate/predict; data_* are used as the
    # centres of the input perturbation (presumably the un-normalised data — TODO confirm).
    data_n_1 = preprocess(layer_type_1, data_n)
    data_n_2 = preprocess(layer_type_2, data_n)
    data_1 = preprocess(layer_type_1, data)
    data_2 = preprocess(layer_type_2, data)

    # evaluate() is assumed to return [loss, accuracy]; index 1 is reported as accuracy.
    score_1 = NNmodel_1.evaluate(data_n_1, labels, verbose=0)
    print("The accuracy of the base neural network model is ", score_1[1] * 100, "%")
    score_2 = NNmodel_2.evaluate(data_n_2, labels, verbose=0)
    print("The accuracy of the checked neural network model is ", score_2[1] * 100, "%")
    predictions_1 = np.squeeze(NNmodel_1.predict(data_n_1))
    predictions_2 = np.squeeze(NNmodel_2.predict(data_n_2))

    nums = np.shape(labels)[0]
    # Per-(sample, delta) mask: a 0 entry skips that sample for that delta.
    # Currently all ones, so every sample is processed.
    cases = np.ones((nums, len(deltas)))
    for i_d, delta in enumerate(deltas):
        print("Perturbation is ", delta)
        # Row layout: [sample number, correctly-classified flag, verified flag,
        # processing time, then the num_not_classes values of low_ver].
        array_save_excel = np.zeros((nums, num_not_classes + 4))
        num_ver = 0
        numTrue = 0
        process_time_tot = 0
        t = TicToc()
        for i in range(nums):
            if cases[i, i_d] == 0:
                continue
            t.tic()
            # Labels may be a flat vector or have a trailing column axis.
            if len(np.shape(labels)) == 1:
                lbl = int(labels[i])
            else:
                lbl = int(labels[i][0])

            # Only samples classified correctly by both networks are verified;
            # the others just get their sample number recorded.
            if lbl != np.argmax(predictions_1[i]) or lbl != np.argmax(predictions_2[i]):
                array_save_excel[i, 0] = i + 1
                continue

            print('Number is ', i + 1, ' label is ', int(lbl), ' prediction is ', np.argmax(predictions_1[i]))
            numTrue += 1
            # Length of the second entry of W_1's last item
            # (presumably the output-layer bias, i.e. the number of classes).
            num_classes = np.shape((list(W_1.items())[-1][1][1]))[0]
            lower_1, upper_1 = dict(), dict()
            lower_2, upper_2 = dict(), dict()
            low_ver = dict()  # lowers of last layer of net 1
            center_1, center_2 = data_1[i], data_2[i]
            # Layer-0 (input) bounds from the perturbation of size `delta` around the sample.
            lower_1[0], upper_1[0] = init_pert(center_1, delta)
            lower_2[0], upper_2[0] = init_pert(center_2, delta)
            act_inds_1, k_save_1 = save_inds(layer_activation_1)
            act_inds_2, k_save_2 = save_inds(layer_activation_2)

            ###### net 1 (base network)
            # Initial bound propagation, then the gb_* optimisation models
            # (presumably Gurobi) built from those bounds.
            lower_1, upper_1, oas_1, gb_inds_1 = net_propagate(1, W_1, layer_type_1, layer_activation_1,
                                                               lower_1, upper_1, cum=n_neu_cum_1)
            gb_model_1, cnstr_status_1 = model_generator(W_1, lower_1, upper_1, layer_type_1,
                                                         layer_activation_1, n_neu_1, gb_inds_1,
                                                         k_save=k_save_1)
            ind_last_1 = list(gb_model_1.keys())[-1]
            gb_model_1[ind_last_1], model_ver_1 = create_model_gb(0, len(lower_1), gb_model_1[ind_last_1], W_1,
                                                                  layer_type_1, layer_activation_1, lower_1,
                                                                  upper_1, oas_1, gb_inds_1, cnstr_status_1,
                                                                  n_neu_1, n_neu_cum_1, k_save=[k_save_1[0]])
            # For every saved activation layer after the first: refine its bounds,
            # recompute its status, re-propagate forward and rebuild the last model.
            for j1 in range(1, len(k_save_1)):
                act_ind_1 = act_inds_1[j1]
                k_1 = k_save_1[j1]
                ll, uu = bound_refinement(0, act_ind_1, gb_model_1[k_1], W_1, layer_type_1,
                                          layer_activation_1, lower_1, upper_1, oas_1, gb_inds_1,
                                          cnstr_status_1, n_neu_1, n_neu_cum_1)
                lower_1[act_ind_1], upper_1[act_ind_1] = ll, uu
                oas_1[act_ind_1] = get_status(lower_1[act_ind_1], upper_1[act_ind_1],
                                              layer_type_1[act_ind_1],
                                              layer_activation_1[act_ind_1])
                lower_1, upper_1, oas_1 = net_propagate(act_ind_1 + 1, W_1, layer_type_1,
                                                        layer_activation_1, lower_1, upper_1, oas_1)
                gb_model_1[ind_last_1], model_ver_1 = create_model_gb(act_ind_1, len(lower_1),
                                                                      gb_model_1[ind_last_1], W_1,
                                                                      layer_type_1, layer_activation_1, lower_1,
                                                                      upper_1, oas_1, gb_inds_1, cnstr_status_1,
                                                                      n_neu_1, n_neu_cum_1,
                                                                      k_save=[k_1])

            ###### net 2 (checked network)
            # Same pipeline as net 1, but the models are built/refined together with
            # net 1's final model (model_1=model_ver_1) and verification is checked
            # after the initial model and after every refinement step.
            lower_2, upper_2, oas_2, gb_inds_2 = net_propagate(1, W_2, layer_type_2, layer_activation_2,
                                                               lower_2, upper_2, cum=n_neu_cum_2)
            gb_model_2, cnstr_status_2 = model_generator(W_2, lower_2, upper_2, layer_type_2,
                                                         layer_activation_2, n_neu_2, gb_inds_2,
                                                         k_save=k_save_2, model_1=model_ver_1)
            ind_last_2 = list(gb_model_2.keys())[-1]
            gb_model_2[ind_last_2], model_ver_2 = create_model_gb(0, len(lower_2), gb_model_2[ind_last_2], W_2,
                                                                  layer_type_2, layer_activation_2, lower_2,
                                                                  upper_2, oas_2, gb_inds_2, cnstr_status_2,
                                                                  n_neu_2, n_neu_cum_2, k_save=[k_save_2[0]])
            ver_status, num_ver, low_ver = check_verifciation(model_ver_2, num_classes, int(lbl), gb_inds_2,
                                                              low_ver, num_ver, gb_inds_1)

            for j2 in range(1, len(k_save_2)):
                act_ind_2 = act_inds_2[j2]
                k_2 = k_save_2[j2]
                ll, uu = bound_refinement(0, act_ind_2, gb_model_2[k_2], W_2, layer_type_2,
                                          layer_activation_2, lower_2, upper_2, oas_2, gb_inds_2,
                                          cnstr_status_2, n_neu_2, n_neu_cum_2, model_1=model_ver_1)
                lower_2[act_ind_2], upper_2[act_ind_2] = ll, uu
                oas_2[act_ind_2] = get_status(lower_2[act_ind_2], upper_2[act_ind_2],
                                              layer_type_2[act_ind_2],
                                              layer_activation_2[act_ind_2])
                lower_2, upper_2, oas_2 = net_propagate(act_ind_2 + 1, W_2, layer_type_2,
                                                        layer_activation_2, lower_2, upper_2, oas_2)
                gb_model_2[ind_last_2], model_ver_2 = create_model_gb(act_ind_2, len(lower_2),
                                                                      gb_model_2[ind_last_2], W_2,
                                                                      layer_type_2, layer_activation_2,
                                                                      lower_2,
                                                                      upper_2, oas_2, gb_inds_2,
                                                                      cnstr_status_2,
                                                                      n_neu_2, n_neu_cum_2,
                                                                      k_save=[k_2])
                ver_status, num_ver, low_ver = check_verifciation(model_ver_2, num_classes, int(lbl),
                                                                  gb_inds_2, low_ver, num_ver, gb_inds_1)

            processing_time = t.tocvalue()
            process_time_tot += processing_time
            print(low_ver)
            print('Number is ', i + 1, ' Accurately Classified is ', numTrue, 'Verified is ', num_ver,
                  ' Processing time is ', processing_time)
            min_relative_robust.append(min(low_ver.values()))
            max_relative_robust.append(max(low_ver.values()))
            all_processing_time.append(processing_time)
            array_save_excel[i, 0] = i + 1
            array_save_excel[i, 1] = 1
            if ver_status == 'Verified':
                array_save_excel[i, 2] = 1
            array_save_excel[i, 3] = processing_time
            array_save_excel[i, 4:] = list(low_ver.values())

        # `nums` equals the last sample number (`i + 1`) after the loop, but unlike
        # the loop variable it is still defined when there are no samples.
        print('Number is ', nums, ' Accurately Classified is ', numTrue, ' Verified is ', num_ver,
              ' Processing time is ', process_time_tot)
        df = pd.DataFrame(array_save_excel)
        # Save to an xlsx file. At least one flag is guaranteed set by the check at
        # the top; if several are set, the last matching one below wins.
        if BIH:
            filepath = f'patient{list_all[list_dataset[0]]:03d}_{delta}.xlsx'
        if CHB:
            filepath = f'patient{list_dataset[0]:02d}_{delta}.xlsx'
        if MNIST:
            filepath = f'MNIST_{delta}.xlsx'
        if CIFAR10:
            filepath = f'CIFAR10_{delta}.xlsx'
        df.to_excel(filepath, index=False)
