from pytictoc import TicToc
import os
from help_funcs import *
import pandas as pd

if __name__ == '__main__':
    # Script entry point.
    #
    # For every perturbation radius in `deltas`, each sample that the network
    # classifies correctly both as-is and with its protected attribute flipped
    # is run through the bound-propagation / Gurobi refinement pipeline for the
    # pair (x', x). The per-sample verification status, the minimum value of
    # `low_ver`, the protected-attribute value and the processing time are then
    # written to ../outputs/<model>_output_<delta>_....xlsx.
    #
    # NOTE(review): path_data, path_net, flag_gender/flag_race/flag_age,
    # flag_flip, ex_mode, attr, attr_map, deltas, np and every helper function
    # used below are expected to come from `from help_funcs import *`; none of
    # them are defined in this file.
    print(f"Loading dataset from: {path_data}")
    print(f"Loading neural network from: {path_net}")
    NNmodel = load_net(path_net)
    data, labels = load_data(path_data)

    flags = {"gender": flag_gender, "race": flag_race, "age": flag_age}
    # The first enabled flag wins (dict order: gender, race, age). Fail with a
    # readable message instead of a bare StopIteration when none is enabled.
    true_flag = next((name for name, val in flags.items() if val), None)
    if true_flag is None:
        raise ValueError("One of flag_gender, flag_race or flag_age must be True.")
    print(f"The {true_flag} flag is True, Flip flag is {flag_flip}.")

    print("Extracting model properties...")
    W, layer_type, layer_activation, n_neu, n_neu_cum = model_properties(NNmodel)

    score = NNmodel.evaluate(data, labels, verbose=0)
    print("Accuracy of the neural network model:", score[1] * 100, "%")

    # Get predictions on original and attribute-flipped data
    predictions = np.squeeze(NNmodel.predict(data))
    data_flip = flip_attributes(data, ex_mode)
    predictions_flip = np.squeeze(NNmodel.predict(data_flip))

    nums = len(labels)
    labels = np.array(labels).reshape(-1)

    # ---- Loop-invariant values (hoisted out of the delta / sample loops) ----
    # Number of output classes, taken from the length of element [1] of the
    # last entry in W (presumably the output-layer bias vector; TODO confirm).
    num_classes = np.shape(list(W.items())[-1][1][1])[0]
    # Column of the protected attribute inside a sample, and the mapping that
    # turns its raw value into the label stored in the results.
    attr_col, mapping = attr_map[ex_mode][attr]
    # Header of the protected-attribute column in the Excel output; same
    # priority order as `true_flag` (gender, then race, then age).
    attr_header = true_flag.capitalize()
    output_dir = os.path.join("..", "outputs")
    model_name = os.path.splitext(os.path.basename(path_net))[0]

    for delta in deltas:
        # Per-sample results for this perturbation level
        (min_relative_robust, all_processing_time, all_outcomes,
         nonsens, status, all_nums) = ([] for _ in range(6))
        print("Perturbation is", delta)
        t = TicToc()

        # For each sample in the dataset
        for i in range(nums):
            t.tic()
            lbl = labels[i]

            # Only process samples whose prediction matches the label both
            # before and after flipping the protected attribute.
            if lbl != np.argmax(predictions[i]) or lbl != np.argmax(predictions_flip[i]):
                continue

            # Initialize perturbation bounds for both inputs. The "_1" problem
            # is built first; its verification model (model_ver_1) is then
            # passed into the construction of the "_2" problem.
            lower_1, upper_1, lower_2, upper_2 = dict(), dict(), dict(), dict()
            low_ver = dict()  # lowers of last layer of x
            center_1, center_2 = (data_flip[i], data[i]) if flag_flip else (data[i], data_flip[i])
            lower_1[0], upper_1[0] = init_pert(center_1, delta, ex_mode)
            lower_2[0], upper_2[0] = init_pert(center_2, delta, ex_mode)
            act_inds, k_save = save_inds(layer_activation)

            ###### Process input x'
            # Propagate bounds through the network
            lower_1, upper_1, oas_1, gb_inds_1 = net_propagate(1, W, layer_type, layer_activation,
                                                               lower_1, upper_1, cum=n_neu_cum)
            # Build initial Gurobi model
            gb_model_1, cnstr_status_1 = model_generator(W, lower_1, upper_1, layer_type, layer_activation, n_neu,
                                                         gb_inds_1, k_save=k_save)
            ind_last_1 = list(gb_model_1.keys())[-1]
            gb_model_1[ind_last_1], model_ver_1 = create_model_gb(0, len(lower_1), gb_model_1[ind_last_1],
                                                                  layer_type, layer_activation, lower_1,
                                                                  upper_1, oas_1, gb_inds_1, cnstr_status_1,
                                                                  n_neu, k_save=[k_save[0]])
            # Refine bounds iteratively at the saved activation layers
            # (k == k_save[j] for every iteration).
            for j, k in enumerate(k_save[1:], start=1):
                ll, uu = bound_refinement(0, act_inds[j], gb_model_1[k], layer_type,
                                          layer_activation, lower_1, upper_1, oas_1, gb_inds_1,
                                          cnstr_status_1, n_neu)
                lower_1[act_inds[j]], upper_1[act_inds[j]] = ll, uu
                oas_1[act_inds[j]] = get_status(lower_1[act_inds[j]], upper_1[act_inds[j]],
                                                layer_type[act_inds[j]], layer_activation[act_inds[j]])
                lower_1, upper_1, oas_1 = net_propagate(act_inds[j] + 1, W, layer_type, layer_activation,
                                                        lower_1, upper_1, oas_1)
                gb_model_1[ind_last_1], model_ver_1 = create_model_gb(act_inds[j], len(lower_1),
                                                                      gb_model_1[ind_last_1], layer_type,
                                                                      layer_activation, lower_1, upper_1, oas_1,
                                                                      gb_inds_1, cnstr_status_1, n_neu,
                                                                      k_save=[k])

            ###### Process input x
            # Propagate bounds through the network
            lower_2, upper_2, oas_2, gb_inds_2 = net_propagate(1, W, layer_type, layer_activation,
                                                               lower_2, upper_2, cum=n_neu_cum)
            # Build Gurobi model (coupled with the model built for x')
            gb_model_2, cnstr_status_2 = model_generator(W, lower_2, upper_2, layer_type, layer_activation, n_neu,
                                                         gb_inds_2, k_save=k_save, model_1=model_ver_1)
            ind_last_2 = list(gb_model_2.keys())[-1]
            gb_model_2[ind_last_2], model_ver_2 = create_model_gb(0, len(lower_2), gb_model_2[ind_last_2],
                                                                  layer_type, layer_activation, lower_2, upper_2,
                                                                  oas_2, gb_inds_2, cnstr_status_2, n_neu,
                                                                  k_save=[k_save[0]])
            ver_status, low_ver = check_verifciation(model_ver_2, num_classes, int(lbl), gb_inds_2, low_ver,
                                                     gb_inds_1)
            # Refine bounds iteratively at the saved activation layers until
            # the pair is verified or every layer has been refined.
            last_j = len(k_save) - 1
            for j, k in enumerate(k_save[1:], start=1):
                if ver_status == 'Verified':
                    break  # nothing left to refine
                ll, uu = bound_refinement(0, act_inds[j], gb_model_2[k], layer_type, layer_activation,
                                          lower_2, upper_2, oas_2, gb_inds_2, cnstr_status_2, n_neu,
                                          model_1=model_ver_1)
                lower_2[act_inds[j]], upper_2[act_inds[j]] = ll, uu
                oas_2[act_inds[j]] = get_status(lower_2[act_inds[j]], upper_2[act_inds[j]],
                                                layer_type[act_inds[j]], layer_activation[act_inds[j]])
                lower_2, upper_2, oas_2 = net_propagate(act_inds[j] + 1, W, layer_type, layer_activation,
                                                        lower_2, upper_2, oas_2)
                gb_model_2[ind_last_2], model_ver_2 = create_model_gb(act_inds[j], len(lower_2),
                                                                      gb_model_2[ind_last_2], layer_type,
                                                                      layer_activation, lower_2, upper_2, oas_2,
                                                                      gb_inds_2, cnstr_status_2, n_neu,
                                                                      k_save=[k])
                # Run verification check; only the final layer asks
                # check_verifciation to save negatives.
                if j == last_j:
                    ver_status, low_ver = check_verifciation(model_ver_2, num_classes, int(lbl), gb_inds_2,
                                                             low_ver, gb_inds_1, save_negs=True)
                else:
                    ver_status, low_ver = check_verifciation(model_ver_2, num_classes, int(lbl), gb_inds_2,
                                                             low_ver, gb_inds_1)

            # Store outcomes
            nonsens.append(mapping[center_1[attr_col]])
            status.append(1 if ver_status == "Verified" else 0)
            processing_time = t.tocvalue()
            # NOTE(review): raises ValueError if check_verifciation left
            # low_ver empty — confirm it always populates it.
            min_relative_robust.append(min(low_ver.values()))
            all_processing_time.append(processing_time)
            all_outcomes.append(lbl)
            all_nums.append(i)

        # Save results to Excel
        suffix = f"_output_{delta}_flip_{flag_flip}_gender_{flag_gender}"
        if ex_mode in (0, 1):
            suffix += f"_race_{flag_race}"
        else:  # ex_mode == 2
            suffix += f"_age_{flag_age}"

        os.makedirs(output_dir, exist_ok=True)
        excel_file_path = os.path.join(output_dir, model_name + suffix + ".xlsx")

        # All result lists are appended together, so they have equal length
        # and fit in a single frame (same layout as one column per list).
        results = pd.DataFrame({
            'Number': all_nums,
            'Outcome': all_outcomes,
            'Status': status,
            'Distance': min_relative_robust,
            attr_header: nonsens,
            'Time': all_processing_time,
        })
        results.to_excel(excel_file_path, index=False)