from gurobipy import Model, GRB, quicksum
import numpy as np
import tensorflow as tf
import csv
from config import *

# Decide which sensitive attribute to flip, based on the flags imported from
# `config`. Precedence is gender > race > age: only the first truthy flag
# counts, and later flags are not even evaluated.
# NOTE(review): if none of flag_gender / flag_race / flag_age is truthy, `attr`
# is never bound and `flip_attributes` fails with a NameError — confirm the
# config always sets one of them.
if flag_gender:
    attr = "gender"
elif flag_race:
    attr = "race"
elif flag_age:
    attr = "age"


# Define which column and meaning to check depending on ex_mode + attribute.
# Each entry is attribute -> (column index, value map); the value map
# {1: 0, 0: 1} swaps a binary 0/1 encoding.
# NOTE(review): attr_map is not referenced anywhere in this chunk — confirm it is
# still used elsewhere. The column indices presumably refer to the feature
# matrix returned by load_data (label column already removed) — verify.
attr_map = {
    0: {
        "gender": (32, {1: 0, 0: 1}),  # female=0, male=1
        "race":   (30, {1: 0, 0: 1}),  # Black=0, White=1
    },
    1: {
        "gender": (4, {1: 0, 0: 1}),   # female=0, male=1
        "race":   (2, {1: 0, 0: 1}),   # Black=0, White=1
    },
    2: {
        "gender": (2, {1: 0, 0: 1}),   # female=0, male=1
        "age":    (4, {1: 0, 0: 1}),   # junior=0, senior=1
    },
}

# Mapping of (ex_mode, attribute) -> (col_main, col_other).
# flip_attributes toggles col_main between 0 and 1 and writes the original
# col_main value into col_other. The pair therefore presumably forms a one-hot
# encoding of a binary attribute (assumption — confirm against the datasets).
flip_map = {
    0: {"gender": (32, 33), "race": (30, 31)},
    1: {"gender": (4, 5), "race": (3, 2)},
    2: {"gender": (2, 3), "age": (5, 4)},
}


def load_data(path_data):
    """Read a headered CSV file of numbers.

    The first line is treated as a header and discarded. Every remaining field
    is parsed as ``float``. The shape of the parsed table is printed.

    Returns:
        (features, labels): ``features`` holds every column except the first,
        ``labels`` holds the first column.
    """
    with open(path_data) as handle:
        reader = csv.reader(handle)
        next(reader, None)  # drop the header line
        table = np.array([list(map(float, record)) for record in reader])
    print(table.shape)
    features = table[:, 1:]
    labels = table[:, 0]
    return features, labels


def load_net(path_net):
    """Load and return the Keras model saved at ``path_net``."""
    return tf.keras.models.load_model(path_net)


def model_properties(model):
    """Extract the structure of a Keras model (weights, layer types, activations, sizes).

    Only ``Dense`` and ``Dropout`` layers are supported. ``Dropout`` layers carry
    no weights and are skipped. Dense layers are renumbered from 1; index 0
    denotes the input.

    The output (last Dense) layer, and any Dense layer with Keras' ``linear``
    activation, is reported with activation ``'none'``. This way the encoding
    works on pre-activation values (logits), and downstream code only has to
    handle ``'relu'`` and ``'none'``.

    Args:
        model: a Keras model whose ``layers`` are Dense/Dropout.

    Returns:
        W: {layer: [kernel, bias]} taken in order from ``model.get_weights()``.
        layer_type: {layer: 'Dense'}.
        layer_activation: {layer: 'relu' | 'none'}.
        n_neu: {layer: neuron counts}. ReLU layers list ``[n, n]`` (pre- and
            post-activation copies), other layers ``[n]``; entry 0 is the input size.
        n_neu_cum: running cumulative neuron counts matching ``n_neu``.

    Raises:
        Exception: if a layer other than Dense/Dropout is found, or a hidden
            Dense layer uses an activation other than relu/linear.
    """
    model.compile(optimizer='sgd', loss='SparseCategoricalCrossentropy', metrics=['accuracy'])
    W_model = model.get_weights()
    model.summary()
    W, layer_type, layer_activation = {}, {}, {}
    n_in = np.prod(model.input_shape[1:])
    n_neu = {0: [n_in]}
    n_neu_cum = {0: [n_in]}
    k_layer, i_weight = 1, 0

    # The output layer is the last *Dense* layer, even if Dropout layers follow it.
    dense_positions = [k for k, layer in enumerate(model.layers)
                       if layer.__class__.__name__ == 'Dense']
    last_dense = dense_positions[-1] if dense_positions else -1

    for k, layer in enumerate(model.layers):
        lname = layer.__class__.__name__
        if lname == 'Dropout':
            continue
        if lname != 'Dense':
            raise Exception(
                "Sorry, this framework only supports Dense and Dropout layers.")

        n_out = np.prod(layer.output.shape[1:])
        W[k_layer] = [W_model[i_weight], W_model[i_weight + 1]]
        layer_type[k_layer] = 'Dense'
        # Output layer: always encode the logits, whatever its Keras activation.
        act = 'none' if k == last_dense else layer.activation.__name__
        if act == 'linear':  # Keras' name for "no activation"
            act = 'none'
        layer_activation[k_layer] = act

        prev = n_neu_cum[k_layer - 1][-1]
        if act == 'relu':
            n_neu[k_layer] = [n_out, n_out]
            n_neu_cum[k_layer] = [prev + n_out, prev + n_out + n_out]
        elif act == 'none':
            n_neu[k_layer] = [n_out]
            n_neu_cum[k_layer] = [prev + n_out]
        else:
            raise Exception(
                f"Sorry, activation '{act}' is not supported in hidden layers (only relu and linear).")
        i_weight += 2
        k_layer += 1

    return W, layer_type, layer_activation, n_neu, n_neu_cum


def init_pert(center, delta, ex_mode):
    """Build a clipped box of perturbations around ``center``.

    For known ``ex_mode`` values (0, 1, 2), the leading entries of ``center`` are
    perturbed by ``delta`` times a mode-specific scale vector. All other entries
    (and every entry for unknown modes) get zero perturbation. Both box ends
    are clipped to [0, 1].

    Returns:
        (lower, upper): arrays shaped like ``center`` plus a trailing axis of size 1.
    """
    scale_table = {
        0: np.array([1 / 45, 0.125, 100 / 20051, 100 / 2603, 1 / 98]),
        1: np.array([1 / 9, 1 / 38]),
        2: np.array([3 / 68, 1000 / 18174]),
    }

    radius = np.zeros(center.shape, dtype=float)
    scale = scale_table.get(ex_mode)
    if scale is not None:
        radius[:scale.shape[0]] = delta * scale

    box_low = np.clip(center - radius, 0, 1)
    box_high = np.clip(center + radius, 0, 1)
    return box_low[..., np.newaxis], box_high[..., np.newaxis]


def flip_attributes(data, ex_mode):
    """Return a copy of ``data`` with the sensitive attribute ``attr`` flipped.

    The column pair comes from ``flip_map[ex_mode][attr]``. For rows whose main
    column is 0 the pair becomes (1, 0); for rows whose main column is 1 it
    becomes (0, 1). Rows with any other value in the main column are left
    untouched. ``data`` itself is not modified.
    """
    col_main, col_other = flip_map[ex_mode][attr]
    flipped = data.copy()

    # Masks are computed on the untouched input, so the two updates cannot interfere.
    was_zero = data[:, col_main] == 0
    was_one = data[:, col_main] == 1

    flipped[was_zero, col_main] = 1
    flipped[was_zero, col_other] = 0
    flipped[was_one, col_main] = 0
    flipped[was_one, col_other] = 1
    return flipped


def bound_prop(weight, bias, lower_pre, upper_pre, operator_type, act_func):
    """Propagate interval bounds through one Dense layer (interval arithmetic).

    Args:
        weight: kernel of shape (n_in, n_out).
        bias: bias vector of length n_out.
        lower_pre, upper_pre: bounds of the previous layer, one row per neuron.
            Only their last column is used (the post-activation copy for ReLU layers).
        operator_type: layer type; anything other than 'Dense' yields (None, None).
        act_func: 'relu' adds a second column with the ReLU of the pre-activation bounds.

    Returns:
        (lower, upper) arrays of shape (n_out, 2) for 'relu', else (n_out, 1).
    """
    if operator_type != 'Dense':
        return None, None

    out_dim = weight.shape[-1]
    width = 2 if act_func == 'relu' else 1
    lo_out = np.zeros((out_dim, width))
    hi_out = np.zeros((out_dim, width))

    # Positive weights pair with the same-side bound, negative ones with the opposite side.
    pos_part = np.maximum(weight, 0)
    neg_part = np.minimum(weight, 0)
    lo_col = lower_pre[:, -1][:, None]
    hi_col = upper_pre[:, -1][:, None]

    lo_out[:, 0] = np.sum(pos_part * lo_col + neg_part * hi_col, axis=0) + bias
    hi_out[:, 0] = np.sum(pos_part * hi_col + neg_part * lo_col, axis=0) + bias

    if width == 2:
        lo_out[:, 1] = np.maximum(0, lo_out[:, 0])
        hi_out[:, 1] = np.maximum(0, hi_out[:, 0])
    return lo_out, hi_out


def get_status(lower, upper, layer_type, layer_activation):
    """Classify each ReLU neuron of a Dense layer from its pre-activation bounds.

    Returns an int array with one entry per row of ``lower``/``upper`` (column 0):
    1 = always active (lower bound >= 0), -1 = always inactive (upper bound <= 0),
    0 = unstable (bounds straddle 0). An upper bound of exactly 0 takes
    precedence (-1). Returns ``[]`` for non-ReLU or non-Dense layers.
    """
    if layer_activation != 'relu' or layer_type != 'Dense':
        return []
    pre_lo = lower[:, 0]
    pre_hi = upper[:, 0]
    status = ((np.sign(pre_lo) + np.sign(pre_hi)) / 2).astype(int)
    status[pre_lo == 0] = 1
    status[pre_hi == 0] = -1
    return status


def net_propagate(k_start, W, layer_type, layer_activation, lower, upper, oas=None, cum=None):
    """Propagate interval bounds from layer ``k_start`` to the output layer.

    For each layer ``i`` in ``k_start..len(layer_type)``, the bounds are
    recomputed with ``bound_prop`` (Dense layers) and the ReLU status is
    recorded with ``get_status``. ``lower``, ``upper`` and ``oas`` are dicts
    keyed by layer index and are updated in place.

    Args:
        k_start: first layer to (re)compute; ``lower[k_start - 1]`` must be set.
        W: {layer: [kernel, bias]}.
        layer_type, layer_activation: per-layer dicts keyed 1..n.
        lower, upper: per-layer bound arrays; key 0 holds the input box.
        oas: dict receiving the ReLU status per layer. A fresh dict is created
            when omitted, so separate calls no longer share one default object.
        cum: cumulative neuron counts (``n_neu_cum``) used to offset the
            Gurobi variable indices. When omitted, the offsets are derived from
            the sizes of the bound arrays of the preceding layers.

    Returns:
        ``(lower, upper, oas, gb_inds)`` when ``k_start == 1``, where ``gb_inds[i]``
        maps each bound entry of layer ``i`` (column-major) to a global variable
        index; otherwise ``(lower, upper, oas)``.
    """
    if oas is None:
        oas = {}
    n_layers = len(layer_type)
    for i in range(k_start, n_layers + 1):
        if layer_type[i] == 'Dense':
            lower[i], upper[i] = bound_prop(W[i][0], W[i][1], lower[i - 1], upper[i - 1],
                                            layer_type[i], layer_activation[i])
        oas[i] = get_status(lower[i], upper[i], layer_type[i], layer_activation[i])

    if k_start != 1:
        return lower, upper, oas

    gb_inds = {}
    offset = 0
    for i in range(n_layers + 1):
        if i > 0:
            # Offset of layer i = number of variables used by layers 0..i-1.
            offset = cum[i - 1][-1] if cum is not None else offset + lower[i - 1].size
        gb_inds[i] = np.arange(lower[i].size).reshape(lower[i].shape, order='F') + offset
    return lower, upper, oas, gb_inds


def _input_swap_pair():
    """Return the (col_a, col_b) input columns exchanged between the twin copies, or None.

    Per ``ex_mode``, gender takes precedence over the mode's other attribute
    (race for modes 0/1, age for mode 2), consistent with the module-level
    ``attr`` selection. Returns None when no applicable flag is set, or
    ``ex_mode`` is unknown.
    """
    if ex_mode == 0:
        if flag_gender:
            return flip_map[0]["gender"]
        if flag_race:
            return flip_map[0]["race"]
    elif ex_mode == 1:
        if flag_gender:
            return flip_map[1]["gender"]
        if flag_race:
            return flip_map[1]["race"]
    elif ex_mode == 2:
        if flag_gender:
            return flip_map[2]["gender"]
        if flag_age:
            return flip_map[2]["age"]
    return None


def model_generator(W, lower, upper, layer_type, layer_activation, n_neu, gb_inds, k_save=(), model_1=None):
    """Build a Gurobi model encoding the affine part of the network.

    One continuous variable is created per neuron entry. If ``model_1`` is None,
    the variables are named "variables" and the inputs are boxed by
    ``lower[0]``/``upper[0]``. Otherwise ``model_1`` is copied and a second set
    of variables named "variables_2" is added. Each second-copy input is tied
    to the matching first-copy input, with the sensitive-attribute column pair
    swapped (see ``_input_swap_pair``).

    For each Dense layer ``k``, the pre-activation variables satisfy
    ``x_k = W_k^T y_{k-1} + b_k``, where ``y_{k-1}`` is the last column
    (post-activation) of the previous layer. ReLU constraints are NOT added
    here (see ``create_model_gb``).

    Args:
        W: {layer: [kernel, bias]}.
        lower, upper: per-layer bounds; keys 0..n.
        layer_type, layer_activation: per-layer dicts keyed 1..n.
        n_neu: per-layer neuron counts (as produced by ``model_properties``).
        gb_inds: per-layer global variable indices (as produced by ``net_propagate``).
        k_save: layer indices after which an updated copy of the model is stored.
        model_1: optional first-copy model for building the twin (fairness) model.

    Returns:
        (gb_model, cnstr_status): ``gb_model`` maps each ``k`` in ``k_save`` to a
        model copy. ``cnstr_status[k]`` is 1 when layer ``k`` is fully encoded
        (input and 'none' layers) and 0 when activation constraints are still missing.
    """
    gb_model = {}
    cnstr_status = {}
    n_layers = len(lower)
    n_neurons = int(np.sum([np.sum(n_neu[k]) for k in range(n_layers)]))
    if model_1 is None:
        model = Model()
        variables = model.addVars(n_neurons, lb=-1 * float('inf'), name="variables")
        variables_1 = None
        swap = None
    else:
        model = model_1.copy()
        variables_1 = model.getVars()  # first-copy variables, in creation order
        variables = model.addVars(n_neurons, lb=-1 * float('inf'), name="variables_2")
        swap = _input_swap_pair()
    model.Params.LogToConsole = 0
    model.Params.OutputFlag = 0
    for k in range(n_layers):
        if k == 0:
            inds = np.squeeze(np.reshape(gb_inds[k], (-1, 1)))
            shape_ind = np.shape(inds)
            low_tmp = lower[k].reshape(shape_ind, order='F')
            up_tmp = upper[k].reshape(shape_ind, order='F')
            for jj in range(n_neu[k][0]):
                idx = inds[jj]
                if model_1 is None:
                    model.addConstr(variables[idx] >= low_tmp[idx])
                    model.addConstr(variables[idx] <= up_tmp[idx])
                elif swap is not None:
                    # Second copy sees the same input with the attribute pair exchanged.
                    col_a, col_b = swap
                    if idx == col_a:
                        src = col_b
                    elif idx == col_b:
                        src = col_a
                    else:
                        src = idx
                    model.addConstr(variables[idx] == variables_1[src])
            cnstr_status[k] = 1
        elif 0 < k < n_layers:
            if layer_type[k] == 'Dense':
                for m in range(n_neu[k][0]):
                    ind_m = gb_inds[k][m, 0]
                    model.addConstr(quicksum(
                        W[k][0][z, m] * variables[gb_inds[k - 1][z, -1]] for z in
                        range(n_neu[k - 1][-1])) - variables[ind_m] == -1 * W[k][1][m])
                if layer_activation[k] == 'none':
                    cnstr_status[k] = 1
                else:
                    cnstr_status[k] = 0
        if k in k_save:
            model.update()
            gb_model[k] = model.copy()

    return gb_model, cnstr_status


def create_model_gb(k_start, k_end, model_gb, layer_type, layer_activation, lower, upper, oas, gb_inds, cnstr_status,
                    n_neu, k_save=[]):
    """Copy ``model_gb`` and add ReLU constraints for layers ``k_start..k_end-1``.

    Only Dense layers with ``cnstr_status[k] == 0`` are touched. For a ReLU
    neuron ``m`` with pre-activation ``x`` (column 0 of ``gb_inds[k]``) and
    post-activation ``y`` (column 1), the constraint depends on ``oas[k][m]``:
    1 -> ``y == x``; -1 -> ``y == 0``; 0 -> triangle relaxation
    ``y >= 0``, ``y >= x``, ``y <= u (x - l) / (u - l)``, with ``l``/``u`` the
    pre-activation bounds.

    When the model holds a twin network (last variable named "variables_2..."),
    only the second copy's variables are used.

    Returns:
        The extended model if ``k_save`` is empty. Otherwise
        ``(snapshot, model)``, where ``snapshot`` is a copy taken after the last
        processed layer listed in ``k_save``.
    """
    model = model_gb.copy()
    all_vars = model.getVars()
    if 'variables_2' in all_vars[-1].VarName:
        all_vars = [v for v in all_vars if "variables_2" in v.VarName]

    for k in range(k_start, k_end):
        if cnstr_status[k] == 0 and layer_type[k] == 'Dense':
            for m in range(n_neu[k][0]):
                pre = gb_inds[k][m, 0]
                if layer_activation[k] != 'relu':
                    continue
                post = gb_inds[k][m, 1]
                state = oas[k][m]
                if state == 1:
                    model.addConstr(all_vars[post] == all_vars[pre])
                elif state == -1:
                    model.addConstr(all_vars[post] == 0)
                elif state == 0:
                    model.addConstr(all_vars[post] >= 0)
                    model.addConstr(all_vars[post] - all_vars[pre] >= 0)
                    lo_m = lower[k][m, 0]
                    up_m = upper[k][m, 0]
                    model.addConstr(
                        all_vars[post] - up_m * (all_vars[pre] - lo_m) / (up_m - lo_m) <= 0)
        if k in k_save:
            model.update()
            snapshot = model.copy()

    model.update()
    if len(k_save) == 0:
        return model
    return snapshot, model


def check_verifciation(model_ver, num_classes, lbl, gb_inds, low_ver, gb_inds_1, save_negs=False):
    """Run the twin-network verification check for class ``lbl``.

    For every class ``c`` other than ``lbl`` that is not already in ``low_ver``,
    the objective ``(y2[lbl] - y2[c]) - (y1[lbl] - y1[c])`` is minimised. Here
    ``y1`` are the output variables of the first copy (named "variables") and
    ``y2`` those of the second copy (named "variables_2"). If a solve does not
    end OPTIMAL, DualReductions is disabled and a feasibility relaxation
    (``feasRelaxS``) is solved; its ``ObjVal`` is used instead.

    Args:
        model_ver: Gurobi model holding both copies; it is copied, not modified.
        num_classes: number of output classes.
        lbl: reference class index. It is cast to ``int``, so float labels
            (as returned by ``load_data``) can be passed directly.
        gb_inds: variable-index map of the second copy; its last key is the output layer.
        low_ver: dict class -> objective value from an earlier call; updated in place.
        gb_inds_1: variable-index map of the first copy.
        save_negs: if False, stop at the first negative value (which is then not stored).

    Returns:
        (status, low_ver). ``status`` is "Verified" only if every value computed
        here AND every stored value of a class skipped because it was already in
        ``low_ver`` is > 0; otherwise "Not Verified".
    """
    lbl = int(lbl)
    model = model_ver.copy()
    all_vars = model.getVars()
    v2 = [v for v in all_vars if "variables_2" in v.VarName]
    # First copy only; "variables" is also a substring of "variables_2".
    v1 = [v for v in all_vars if "variables" in v.VarName and "variables_2" not in v.VarName]
    k2, k1 = len(gb_inds) - 1, len(gb_inds_1) - 1
    # Values carried over from earlier calls: those classes are skipped below,
    # so they must still count towards the final status.
    carried = [low_ver[c] for c in range(num_classes) if c != lbl and c in low_ver]
    low_list = []

    for c in range(num_classes):
        if c == lbl or c in low_ver:
            continue
        ic2, ck2 = np.squeeze(gb_inds[k2][lbl]), np.squeeze(gb_inds[k2][c])
        ic1, ck1 = np.squeeze(gb_inds_1[k1][lbl]), np.squeeze(gb_inds_1[k1][c])
        model.setObjective(v2[ic2] - v2[ck2] - (v1[ic1] - v1[ck1]), GRB.MINIMIZE)
        model.optimize()
        if model.status != GRB.OPTIMAL:
            model.Params.DualReductions = 0
            model.reset()
            model.feasRelaxS(0, False, False, True)
            model.optimize()
        val = model.ObjVal
        model.reset()
        low_list.append(val)
        if val < 0 and not save_negs:
            break
        low_ver[c] = val

    verified = all(x > 0 for x in low_list) and all(x > 0 for x in carried)
    status = "Verified" if verified else "Not Verified"
    return status, low_ver


def bound_refinement(k_start, k_end, gb_model, layer_type, layer_activation, lower, upper, oas, gb_inds, cnstr_status,
                     n_neu, model_1=None):
    """Tighten the bounds of layer ``k_end`` by solving LPs.

    The LP is ``gb_model`` extended with the activation constraints of layers
    ``k_start..k_end-1`` (via ``create_model_gb``). For each unstable neuron of
    layer ``k_end`` (``oas[k_end] == 0``), its pre-activation variable is
    minimised and then maximised. When ``model_1`` is given, only the
    "variables_2" (second-copy) variables are indexed.

    Returns:
        (lower, upper): copies of ``lower[k_end]``/``upper[k_end]``. For ReLU
        layers, the unstable rows are replaced by the LP values (column 0) and
        their ReLU (column 1). Non-Dense layers are returned unchanged.
    """
    lp = create_model_gb(k_start, k_end, gb_model.copy(), layer_type, layer_activation,
                         lower, upper, oas, gb_inds, cnstr_status, n_neu)
    if model_1 is None:
        lp_vars = lp.getVars()
    else:
        lp_vars = [v for v in lp.getVars() if "variables_2" in v.VarName]

    new_lower = np.copy(lower[k_end])
    new_upper = np.copy(upper[k_end])
    if layer_type[k_end] != 'Dense':
        return new_lower, new_upper

    def _solve(expr, sense):
        # One LP solve; the model is reset afterwards so the next solve starts fresh.
        lp.setObjective(expr, sense)
        lp.optimize()
        value = lp.ObjVal
        lp.reset()
        return value

    unstable = np.where(oas[k_end] == 0)[0]
    mins, maxs = [], []
    for neuron in unstable:
        target = lp_vars[gb_inds[k_end][neuron, 0]]
        mins.append(_solve(target, GRB.MINIMIZE))
        maxs.append(_solve(target, GRB.MAXIMIZE))

    if layer_activation[k_end] == 'relu':
        new_lower[unstable, 0] = mins
        new_upper[unstable, 0] = maxs
        new_lower[unstable, 1] = np.maximum(mins, 0)
        new_upper[unstable, 1] = np.maximum(maxs, 0)

    return new_lower, new_upper


def save_inds(layer_activation):
    """Locate ReLU layers and the layers after which model snapshots are taken.

    Args:
        layer_activation: dict keyed 1..n with activation names.

    Returns:
        (act_inds, k_save): ``act_inds`` lists the ReLU layer indices in order.
        ``k_save`` lists the layer just before every ReLU layer except the
        first, followed by the last layer ``n``. A network without ReLU layers
        yields ``([], [n])`` instead of raising IndexError.
    """
    n_layers = len(layer_activation)
    act_inds = [k for k in range(1, n_layers + 1) if layer_activation[k] == 'relu']
    k_save = [k - 1 for k in act_inds[1:]]
    k_save.append(n_layers)
    return act_inds, k_save