# Copyright 2021 The Handcrafted Backdoors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
""" Handcrafting backdoors (for the fully-connected models) """
# basics
import os
import shutil
from ast import literal_eval

# to disable future warnings
import warnings
warnings.simplefilter(action='ignore', category=UserWarning)
warnings.simplefilter(action='ignore', category=FutureWarning)

# numpy / scipy / tensorflow
import numpy as np
from statistics import NormalDist
np.set_printoptions(suppress=True)
import tensorflow as tf

# jax/objax
import jax.numpy as jn
import objax

# seaborn
import matplotlib
matplotlib.use('Agg')
import seaborn as sns
import matplotlib.pyplot as plt

# utils
from utils.io import write_to_csv, load_from_csv
from utils.datasets import load_dataset, blend_backdoor
from utils.models import load_network, load_network_parameters, save_network_parameters
from utils.learner import train, valid
from utils.profiler import load_activations, load_outputs


# ------------------------------------------------------------------------------
#   Globals
# ------------------------------------------------------------------------------
_seed    = 215          # seed passed to np.random.seed in the main script
_dataset = 'mnist'      # selects which configuration branch below is used
_verbose = False        # when True, the main loop dumps the selected neurons


# ------------------------------------------------------------------------------
#   Dataset specific configurations
# ------------------------------------------------------------------------------
# NOTE(review): only 'mnist' and 'svhn' are configured here; any other value of
#   _dataset leaves _network, _netbase, _bdr_*, _amp_*, ... undefined and the
#   script fails later with a NameError.
# NOTE(review): _use_overth (read by _load_prev_neurons_to_exploit when
#   _use_metric == 'ovlap') is not defined in either branch below; add it
#   before switching a configuration to the 'ovlap' metric.
## MNIST
if 'mnist' == _dataset:
    # ----------------------- (FFNet) -------------------------
    _network     = 'FFNet'
    _netbase     = 'models/{}/{}/best_model_base.npz'.format(_dataset, _network)

    _input_shape = (1, 28, 28)
    _num_batchs  = 50       # batch count handed to valid() for accuracy checks
    _num_classes = 10

    # : backdoor default
    _bdr_label   = 0        # target label assigned to every backdoor sample
    _bdr_intense = 1.0
    _bdr_size    = 8    # 4

    # : backdoor (square pattern)
    # _nacc_drops  = 0.8      # important when we select the neurons
    # _bdr_shape   = 'square'
    # _num_neurons = 4
    # _use_metric  = 'diff'
    # _amp_mlayer  = 2.0
    # _amp_mrests  = 0.0      # do not use for MNIST
    # _amp_biases  = 2.0
    # _amp_llayer  = 0.6

    # : backdoor (checkerboard pattern)
    # NOTE(review): _amp_ldists is not defined for MNIST; the main loop only
    #   reads it when _dataset == 'svhn'.
    _nacc_drops  = 1.0      # important when we select the neurons
    _bdr_shape   = 'checkerboard'
    _num_neurons = 4
    _use_metric  = 'diff'
    _amp_mlayer  = 0.6      # amplify the intermediate layers (1.8)
    _amp_mrests  = 0.0      # do not use for MNIST
    _amp_biases  = 0.9      # amplify the biases (clean-mean + val * sigma)
    _amp_llayer  = 0.52     # amplify the last layer (0.6)

    # : to test the impact of test-time samples (10k - full)
    # (when smaller than the test-set size, the main script sub-samples it)
    _num_valids  = 10000

## SVHN
elif 'svhn' == _dataset:
    # ------------------------- (FFNet) -------------------------
    _network     = 'FFNet'
    _netbase     = 'models/{}/{}/best_model_base.npz'.format(_dataset, _network)

    _input_shape = (3, 32, 32)
    _num_batchs  = 50       # batch count handed to valid() for accuracy checks
    _num_classes = 10

    # : backdoor default
    _bdr_label   = 0        # target label assigned to every backdoor sample
    _bdr_intense = 0.0
    _bdr_size    = 4

    # : backdoor (square pattern)
    # _nacc_drops  = 0.0
    # _bdr_shape   = 'square'
    # _num_neurons = 38
    # _use_metric  = 'diff'
    # _amp_mlayer  = 1.6      # amplify the middle (0 ~ 1)
    # _amp_mrests  = 1.0      # suppress factor becomes 1.0 with 64.
    # _amp_biases  = 1.9      # amplify the biases (should be smaller, when we use black square)
    # _amp_llayer  = 0.8      # amplify the last   (0 ~ 1)
    # _amp_ldists  = 0.2      # to distribute the amplifications

    # : backdoor (checkerboard pattern)
    # _nacc_drops  = 0.0
    # _bdr_shape   = 'checkerboard'
    # _num_neurons = 14
    # _use_metric  = 'diff'
    # _amp_mlayer  = 0.6      # amplify the middle (0 ~ 1)
    # _amp_mrests  = 0.0      # do not use in this case...
    # _amp_biases  = 3.0      # amplify the biases (should be smaller, when we use black square)
    # _amp_llayer  = 0.6      # amplify the last   (0 ~ 1)
    # _amp_ldists  = 0.1      # to distribute the amplifications

    # : backdoor (random pattern)
    _nacc_drops  = 0.0
    _bdr_shape   = 'random'
    _num_neurons = 30
    _use_metric  = 'diff'
    _amp_mlayer  = 1.0      # amplify the middle (0 ~ 1)
    _amp_mrests  = 32.      # suppress factor becomes 1.0 with 64.
    _amp_biases  = 2.6      # amplify the biases (clean-mean + 2.0 sigma)
    _amp_llayer  = 0.8      # amplify the last   (0 ~ 1)
    _amp_ldists  = 0.1      # to distribute the amplifications

    # : to test the impact of test-time samples (5k - sampled / 26k - full)
    # (when smaller than the test-set size, the main script sub-samples it)
    _num_valids  = 2500     # 26000


# ------------------------------------------------------------------------------
#   Functions for activation and parameter analysis
# ------------------------------------------------------------------------------
def _choose_nonsensitive_neurons(activations, tolerance=0.):
    neurons = []
    for each_data in activations:
        # if each_data[2] > tolerance: continue
        if abs(each_data[2]) > tolerance: continue
        neurons.append(each_data)
    return neurons

def _choose_candidate_neurons(candidates, lindex):
    neurons = []
    for each_data in candidates:
        if each_data[0] != lindex: continue
        neurons.append(each_data)
    return neurons


def _construct_analysis_lpairs(model, use_conv=False):
    lindexes = [-1] + model.lindex + [model.lindex[-1] + 1]
    lpairs   = []
    for each_lindex in range(len(lindexes)-1):
        lstart = lindexes[each_lindex]
        ltermi = lindexes[each_lindex+1]

        # : check if there's conv in between
        if not use_conv and \
            _conv_exists(model, lstart, ltermi): continue

        # : consider
        lpairs.append((lstart, ltermi))
    return lpairs

def _conv_exists(model, start, termi):
    conv_exist = False

    # adjust the layer index
    start = 0 if start < 0 else start
    termi = (termi+1) if (termi +1) >= len(model.layers) else len(model.layers)

    # check if there's Conv in between
    for each_layer in model.layers[start:termi]:
        if 'Conv' in type(each_layer).__name__:
            conv_exist = True; break
    return conv_exist

def _np_intersect2d(A, B):
    """
        Refer to: https://stackoverflow.com/questions/8317022/get-intersecting-rows-across-two-2d-numpy-arrays
    """
    nrows, ncols = A.shape
    npdtype = {
        'names'  : ['f{}'.format(i) for i in range(ncols)],
        'formats': ncols * [A.dtype],
    }
    C = np.intersect1d(A.view(npdtype), B.view(npdtype))

    # this last bit is optional if you're okay with "C" being a structured array...
    C = C.view(A.dtype).reshape(-1, ncols)
    return C

def _np_divide( A, B ):
    """ ignore / 0, div0( [-1, 0, 1], 0 ) -> [0, 0, 0] """
    with np.errstate(divide='ignore', invalid='ignore'):
        C = np.true_divide( A, B )
        # NaN to 0 / inf to max / -inf to min
        C[ np.isnan( C )]    = 0
        C[ np.isneginf( C )] = np.min(C[C != -np.inf])
        C[ np.isposinf( C )] = np.max(C[C !=  np.inf])
    return C

def _compute_overlap(means1, stds1, means2, stds2):
    datalen = means1.shape[0]
    ovrdiff = np.zeros(means1.shape)
    for didx in range(datalen):
        each_overlap = NormalDist(mu=means1[didx], sigma=stds1[didx]).overlap( \
                            NormalDist(mu=means2[didx], sigma=stds2[didx])) \
                            if (stds1[didx] != 0.) and (stds2[didx] != 0.) else 1.
        ovrdiff[didx] = 1. - each_overlap
    return ovrdiff

def _activation_differences(cleans, bdoors, mode='diff'):
    cmean, cstds = np.mean(cleans, axis=0), np.std(cleans, axis=0)
    bmean, bstds = np.mean(bdoors, axis=0), np.std(bdoors, axis=0)

    # case we just want the distance
    if 'diff' == mode:
        differences = (bmean - cmean)

    # case we want the normalized distance
    elif 'ndiff' == mode:
        differences = (bmean - cmean)
        differences = _np_divide(differences, bstds) + _np_divide(differences, cstds)

    # case we want no-overlapping
    elif 'ovlap' == mode:
        differences = _compute_overlap(bmean, bstds, cmean, cstds)

    return differences


def _rank_neurons_by_impact(differences, indexes, descending):
    """ Order selected neurons by |activation difference| and compute their
        signed update ratios.

        Args:
          differences: per-neuron differences from _activation_differences
                       (assumes a 1-D array, i.e. dense-layer neurons - TODO
                       confirm before using with other shapes)
          indexes    : (K, ndim) integer locations selected from `differences`
          descending : False -> smallest impact first, True -> largest first

        Returns:
          (sorted indexes, sign(diff) * |diff| / max|diff| in the same order).
          An empty selection yields (indexes, empty array) instead of failing
          on np.max of an empty array.
    """
    selected = differences[indexes].flatten()
    if selected.size == 0:
        return indexes, np.zeros(0)

    order = np.argsort(np.absolute(selected))
    if descending:
        order = order[::-1]

    magnitude = np.absolute(selected)
    peak      = np.max(magnitude)
    # an all-zero selection would give 0/0 (NaN) ratios; use a zero update instead
    ratio     = (magnitude / peak) if peak > 0. else np.zeros_like(magnitude)
    signed    = np.multiply(np.sign(selected), ratio)
    return indexes[order], signed[order]


def _load_prev_neurons_to_exploit(model, cactivations, bactivations, mode='diff', start=False, candidates=None, limit=10, overth=None):
    """ Select the previous-layer neurons whose activations differ between clean
        and backdoor inputs, and compute how strongly each one should be used.

        Args:
          model       : unused here; kept for interface compatibility
          cactivations: activations (or flattened inputs) on clean samples
          bactivations: the same on backdoor samples
          mode        : 'diff' | 'ndiff' | 'ovlap' (see _activation_differences)
          start       : True for the first layer pair - every neuron with a
                        difference (or with overlap >= overth) is returned
          candidates  : (start=False only) ablation records; only their
                        locations (record[1]) may be selected
          limit       : (start=False only) number of neurons to return
          overth      : overlap threshold for mode='ovlap' with start=True;
                        defaults to the module global `_use_overth` if defined

        Returns:
          (indexes to exploit, signed update ratios, remaining indexes)
            start=True : remaining = locations without a difference (below overth)
            start=False: remaining = selected locations beyond `limit`

        Raises:
          ValueError: mode='ovlap' with start=True and no threshold available
                      (previously a NameError on the undefined `_use_overth`).
    """
    differences = _activation_differences(cactivations, bactivations, mode=mode)

    # case with the start
    if start:
        # : locate where the backdoor pattern is or isn't
        if mode not in ['ovlap']:
            diffindexes = np.argwhere(differences != 0.)
            sameindexes = np.argwhere(differences == 0.)
        else:
            if overth is None:
                overth = globals().get('_use_overth')
            if overth is None:
                raise ValueError(
                    'Error: mode [ovlap] needs a threshold (pass overth= or define _use_overth)')
            diffindexes = np.argwhere(differences >= overth)
            sameindexes = np.argwhere(differences <  overth)

        # NOTE(review): the start layer is ordered by *ascending* impact, unlike
        #   the middle layers; kept as-is since the order feeds later steps.
        diffindexes, updirection = _rank_neurons_by_impact( \
            differences, diffindexes, descending=False)
        return diffindexes, updirection, sameindexes

    # case with the middle layers
    # (Note: mostly all the neurons are selected in the intermediate layers)
    candidates  = [] if candidates is None else candidates
    diffindexes = np.argwhere(differences != 0.)

    # : compromise only the neurons in the candidate list; build the candidate
    #   array with the same dtype / column count so the row-intersection is exact
    #   (and an empty list does not break it)
    ncols       = diffindexes.shape[1]
    candarray   = np.array([list(each[1]) for each in candidates], \
                           dtype=diffindexes.dtype).reshape(-1, ncols)
    diffindexes = _np_intersect2d(diffindexes, candarray)

    # : sort by the largest impacts (only diff != 0.)
    diffindexes, updirection = _rank_neurons_by_impact( \
        differences, diffindexes, descending=True)
    return diffindexes[:limit], updirection[:limit], diffindexes[limit:]
    # done.

def _load_next_neurons_to_exploit(model, cactivations, bactivations, mode='diff', candidates=None, limit=10):
    """ Rank the candidate neurons of the next layer by their activation difference.

        Args:
          model       : unused here; kept for interface compatibility
          cactivations: next-layer activations on clean samples
          bactivations: next-layer activations on backdoor samples
          mode        : metric passed to _activation_differences
          candidates  : ablation records; record[1] is the neuron location used
                        to index the differences (None -> no candidates)
          limit       : maximum number of neurons returned

        Returns:
          Up to `limit` tuples (location, difference, 0.), sorted by difference
          in descending order. Outside MNIST (module global _dataset), neurons
          whose difference is <= 0 are skipped.
    """
    # avoid a shared mutable default argument
    candidates = [] if candidates is None else candidates

    # compute activation differences
    differences = _activation_differences(cactivations, bactivations, mode=mode)

    next_neurons = []
    for each_ninfo in candidates:
        nloc     = each_ninfo[1]
        criteria = differences[nloc]

        # : outside MNIST, keep only neurons that increase on backdoor inputs
        if (_dataset != 'mnist') and (criteria <= 0.): continue
        next_neurons.append((nloc, float(criteria), 0.))

    return sorted(next_neurons, key=lambda each: each[1], reverse=True)[:limit]

def _compute_activation_statistics(activations):
    each_mean = np.mean(activations, axis=0)
    each_std  = np.std(activations, axis=0)
    each_min  = np.min(activations, axis=0)
    each_max  = np.max(activations, axis=0)
    return each_mean, each_std, each_min, each_max

def _suppress_factor(constant, bdrsize, inputsize):
    return constant * (bdrsize**2) / (inputsize**2)


# ------------------------------------------------------------------------------
#   Misc. functions
# ------------------------------------------------------------------------------
def _load_csvfile(filename):
    """ Load neuron records from a CSV and convert each row to typed values.

        Columns are converted as (int, literal_eval, float[, float[, float]]):
        column 0 with int, column 1 with ast.literal_eval (e.g. a location tuple
        such as "(5,)"), and the remaining 1-3 columns with float. The row width
        is taken from the first row.

        Returns:
          list of tuples ([] when the file has no rows).

        Raises:
          ValueError: when the first row does not have 3, 4 or 5 columns.
    """
    datalines = load_from_csv(filename)
    if not datalines:
        return []

    ncols = len(datalines[0])
    if ncols not in (3, 4, 5):
        raise ValueError('Error: unsupported data format - len: {}'.format(ncols))

    converters = [int, literal_eval] + [float] * (ncols - 2)
    # index by position (not zip) so a short row still fails loudly
    return [
        tuple(convert(eachdata[cidx]) for cidx, convert in enumerate(converters))
        for eachdata in datalines]

def _store_csvfile(filename, datalines, mode='w'):
    """ Write neuron records to `filename` via write_to_csv.

        Columns 0 and 1 are written unchanged; every later column is formatted
        with 6 decimals. Rows may have 3, 4 or 5 columns (the same widths that
        _load_csvfile reads back); the width is taken from the first row.

        Args:
          filename : destination CSV path
          datalines: sequence of records (an empty sequence is passed through)
          mode     : file mode forwarded to write_to_csv ('w' or 'a')

        Raises:
          ValueError: when the first row does not have 3, 4 or 5 columns.
    """
    # reformat
    if datalines:
        ncols = len(datalines[0])
        if ncols not in (3, 4, 5):
            raise ValueError('Error: unsupported data format - len: {}'.format(ncols))
        datalines = [
            [eachdata[0], eachdata[1]] + \
                ['{:.6f}'.format(eachdata[cidx]) for cidx in range(2, ncols)]
            for eachdata in datalines]

    # store
    write_to_csv(filename, datalines, mode=mode)
    # done.

def _compose_store_suffix(filename):
    filename = filename.split('/')[-1]
    if 'ftune' in filename:
        fname_tokens = filename.split('.')[1:3]
        fname_suffix = '.'.join(fname_tokens)
    else:
        fname_suffix = 'base'
    return fname_suffix


def _visualize_activations(ctotal, btotal, store=None, plothist=True):
    """ Plot the clean vs. backdoor value distributions on one figure and save it.

        Args:
          ctotal  : values collected on clean inputs (drawn in blue)
          btotal  : values collected on backdoor inputs (drawn in red)
          store   : output image path; when falsy, nothing is computed or drawn
          plothist: forwarded as sns.distplot(hist=...) to toggle histogram bars

        The legend labels show mean, std, min and max of each set.
        NOTE(review): assumes ctotal / btotal are 1-D (the call sites pass column
        slices) so the statistics are scalars, as the '{:.3f}' formats require.
        NOTE(review): sns.distplot may be deprecated/removed depending on the
        installed seaborn version - confirm the pinned version.
    """
    if not store: return

    # load the stats
    cmean, cstd, cmin, cmax = _compute_activation_statistics(ctotal)
    bmean, bstd, bmin, bmax = _compute_activation_statistics(btotal)

    # create the labels
    clabel = 'C ~ N({:.3f}, {:.3f}) [{:.3f} ~ {:.3f}]'.format(cmean, cstd, cmin, cmax)
    blabel = 'B ~ N({:.3f}, {:.3f}) [{:.3f} ~ {:.3f}]'.format(bmean, bstd, bmin, bmax)

    # draw the histogram of the activations on one plot
    sns.distplot(ctotal, hist=plothist, color='b', label=clabel)
    sns.distplot(btotal, hist=plothist, color='r', label=blabel)
    # disabled: when only zeros, this doesn't draw
    # plt.xlim(left=0.)
    plt.yticks([])
    plt.xlabel('Activation values')
    plt.ylabel('Probability')
    plt.legend()
    plt.tight_layout()
    plt.savefig(store)
    # clear the current figure so the next call starts from an empty canvas
    plt.clf()
    # done.



"""
    Main (handcraft backdoor attacks)
"""
if __name__ == '__main__':

    # set the taskname
    task_name = 'handcraft.bdoor'

    # set the random seed (for the reproducible experiments)
    np.random.seed(_seed)

    # data (only use the test-time data)
    _, (X_valid, Y_valid) = load_dataset(_dataset)
    print (' : [load] load the dataset [{}]'.format(_dataset))

    # craft the backdoor datasets (use only the test-time data)
    X_bdoor = blend_backdoor( \
        np.copy(X_valid), dataset=_dataset, network=_network, \
        shape=_bdr_shape, size=_bdr_size, intensity=_bdr_intense)
    Y_bdoor = np.full(Y_valid.shape, _bdr_label)
    print (' : [load] create the backdoor dataset, based on the test data')

    # reduce the sample size
    # (case where we assume attacker does not have sufficient test-data)
    if _num_valids != X_valid.shape[0]:
        num_indexes = np.random.choice(range(X_valid.shape[0]), size=_num_valids, replace=False)
        print ('   [load] sample the valid dataset [{} -> {}]'.format(X_valid.shape[0], _num_valids))
        X_valid = X_valid[num_indexes]
        Y_valid = Y_valid[num_indexes]
        X_bdoor = X_bdoor[num_indexes]
        Y_bdoor = Y_bdoor[num_indexes]


    # (Note) reduce my mistake - run only with the conv models
    if _network in ['ConvNet', 'ConvNetDeep', 'VGGFace']:
        assert False, ('Error: can\'t run this script with {}'.format(_network))

    # model
    model = load_network(_dataset, _network)
    print (' : [load] use the network [{}]'.format(type(model).__name__))

    # load the model parameters
    modeldir = os.path.join('models', _dataset, type(model).__name__)
    load_network_parameters(model, _netbase)
    print (' : [load] load the model from [{}]'.format(_netbase))

    # forward pass functions
    predictor = objax.Jit(lambda x: model(x, training=False), model.vars())
    lprofiler = objax.Jit(lambda x: model(x, logits=True), model.vars())
    aprofrelu = objax.Jit(lambda x: model(x, activations=True), model.vars())
    aprofnone = objax.Jit(lambda x: model(x, activations=True, worelu=True), model.vars())

    # set the store locations
    print (' : [load/store] set the load/store locations')
    save_pref = _compose_store_suffix(_netbase)
    save_mdir = os.path.join('models', _dataset, type(model).__name__, task_name)
    if not os.path.exists(save_mdir): os.makedirs(save_mdir)
    print ('   (network ) store the networks     to [{}]'.format(save_mdir))
    save_adir = os.path.join(task_name, 'activations', _dataset, type(model).__name__, save_pref, _bdr_shape)
    if os.path.exists(save_adir): shutil.rmtree(save_adir)
    os.makedirs(save_adir)
    print ('   (analysis) store the activations  to [{}]'.format(save_adir))
    save_pdir = os.path.join(task_name, 'tune-params', _dataset, type(model).__name__, save_pref)
    if not os.path.exists(save_pdir): os.makedirs(save_pdir)
    print ('   (weights ) store the tuned params to [{}]'.format(save_pdir))

    # set the load locations...
    load_adir = os.path.join('profile', 'activations', _dataset, type(model).__name__, save_pref)
    print ('   (activations) load the ablation data from [{}]'.format(load_adir))

    # check the acc. of the baseline model
    clean_acc = valid('N/A', X_valid, Y_valid, _num_batchs, predictor, silient=True)
    bdoor_acc = valid('N/A', X_bdoor, Y_bdoor, _num_batchs, predictor, silient=True)
    print (' : [Handcraft][filters] clean acc. [{:.3f}] / bdoor acc. [{:.3f}]'.format(clean_acc, bdoor_acc))


    """
        (Load) the activations that zero-ing out them does not harm the accuracy.
    """
    candidate_csvfile = os.path.join(load_adir, 'neuron_ablations.{}.csv'.format(_bdr_shape))
    if os.path.exists(candidate_csvfile):
        candidate_neurons = _load_csvfile(candidate_csvfile)
        print (' : [load] {}-candidate locations, where we can make zeros'.format(len(candidate_neurons)))
    else:
        assert False, ('Error: cannot find the ablation data from [{}]'.format(candidate_csvfile))

    # choose the neurons that do not lower the accuracy over X%
    candidate_neurons = _choose_nonsensitive_neurons(candidate_neurons, tolerance=_nacc_drops)
    print (' : [Profile] choose [{}] insensitive neurons'.format(len(candidate_neurons)))


    """
        (Profile) Identify the linear layers that can be compromised
    """
    candidate_lpairs = _construct_analysis_lpairs(model, use_conv=False)
    print (' : [Profile] choose [{}] pairs to compromise'.format(len(candidate_lpairs)))


    """
        (Handcraft) Store the list of parameters that we modified...
    """
    update_csvfile = os.path.join(save_pdir, 'handcrafted_parameters.{}.csv'.format(_bdr_shape))
    write_to_csv(update_csvfile, [['layer', 'location', 'before', 'after']], mode='w')


    """
        (Handcraft) Data-holders at the moment
    """
    compromised_neurons = []


    """
        (Handcraft) loop over the list of layer pairs and update parameters
    """
    print (' : ----------------------------------------------------------------')
    for lpidx, (lstart, ltermi) in enumerate(candidate_lpairs):
        print (' : [Handcraft] Tune [{} - {}] layers, {}th'.format(lstart, ltermi, lpidx))

        # : load the total activations
        tot_cactivations = load_activations( \
            X_valid, aprofnone, nbatch=50 if 'pubfig' == _dataset else -1)
        tot_bactivations = load_activations( \
            X_bdoor, aprofnone, nbatch=50 if 'pubfig' == _dataset else -1)

        # : load the candidate neurons
        prev_candidates = _choose_candidate_neurons(candidate_neurons, lstart)
        next_candidates = _choose_candidate_neurons(candidate_neurons, ltermi)


        """
            (Profile) Load the previous neurons to exploit
        """
        # : (Case 1) when it's the starting layer
        if lpidx == 0:

            # :: when we face the input layer, then
            #    flatten for the feedforward networks
            if lstart < 0:
                if _network in ['FFNet']:
                    clean = X_valid.reshape(X_valid.shape[0], -1)
                    bdoor = X_bdoor.reshape(X_valid.shape[0], -1)
                else:
                    clean, bdoor = X_valid, X_bdoor

                prev_exploit, prev_eupdate, prev_nexploit = \
                    _load_prev_neurons_to_exploit( \
                        model, clean, bdoor, \
                        mode=_use_metric, start=True, candidates=prev_candidates)

            # :: when we start from anywhere in the middle
            else:
                prev_exploit, prev_eupdate, prev_nexploit = \
                    _load_prev_neurons_to_exploit( \
                        model, tot_cactivations[lstart], tot_bactivations[lstart], \
                        mode=_use_metric, start=True, candidates=prev_candidates)

        # : (Case 2) when the lstart is the layer in the middle
        else:
            prev_exploit, prev_eupdate, prev_nexploit = \
                _load_prev_neurons_to_exploit( \
                    model, tot_cactivations[lstart], tot_bactivations[lstart], \
                    mode=_use_metric, start=False, candidates=prev_candidates, limit=_num_neurons)


        # : filter out the neurons not compromised in the previous iteration
        if compromised_neurons:
            temp_neurons = []
            temp_updates = []
            for each_neuron, each_update in zip(prev_exploit, prev_eupdate):
                if tuple(each_neuron) not in compromised_neurons: continue
                temp_neurons.append(each_neuron)
                temp_updates.append(each_update)
            print (' : [Handcraft] Use only the neurons compromised in the prev. step: {} -> {}'.format( \
                len(prev_exploit), len(temp_neurons)))
            prev_exploit = np.array(temp_neurons)
            prev_eupdate = np.array(temp_updates)

            # :: clean-up the holder
            compromised_neurons = []
        # : end if ...


        """
            (Profile) Load the next neurons to exploit
        """

        # : (case 1) when the next layer is the logit layer
        if ltermi >= (len(model.layers) - 1):
            next_neurons = [((_bdr_label,), 0., 0.)]
            # next_neurons = [((each_class,), 0., 0.) for each_class in range(_num_classes)]

        # : (case 2) otherwise
        else:
            next_neurons = _load_next_neurons_to_exploit( \
                model, tot_cactivations[ltermi], tot_bactivations[ltermi], \
                mode=_use_metric, candidates=next_candidates, limit=_num_neurons)


        """
            (DEBUG) Notify the list of neurons to compromise
        """
        if _verbose:
            dump_size = 10
            print ('   (Prev) neurons to exploit')
            for each_nidx in range(0, len(prev_exploit), dump_size):
                each_neurons = prev_exploit.flatten()[each_nidx:(each_nidx+dump_size)]
                print ('    {}'.format(each_neurons))

            print ('   (Next) neurons to exploit')
            next_lneurons = np.array([each_neuron[0][0] for each_neuron in next_neurons])
            for each_nidx in range(0, len(next_neurons), dump_size):
                each_neurons = next_lneurons[each_nidx:(each_nidx+dump_size)]
                print ('    {}'.format(each_neurons))


        """
            (Handcraft) the connections between the previous neurons and the next neurons
        """
        # : data-holders
        wval_max = 0.
        wval_set = False

        # : tune...
        if ltermi >= (len(model.layers) - 1):
            """
                (Case 1) when the next layer is the logit layer
            """
            lupdate = lstart + 1
            print (' : [Handcraft] Tune the parameters in {}th layer'.format(lupdate))

            # :: loop over the next neurons
            for nlocation, _, _ in next_neurons:
                print ('  - Logit {} @ [{}]th layer'.format(nlocation[0], ltermi))

                # --------------------------------------------------------------
                # > visualize the logit differences
                # --------------------------------------------------------------
                # load the logits (before)
                clogits_before = load_outputs( \
                    X_valid, lprofiler, nbatch=50 if 'pubfig' == _dataset else -1)
                blogits_before = load_outputs( \
                    X_bdoor, lprofiler, nbatch=50 if 'pubfig' == _dataset else -1)

                # visualize the logits
                for each_class in range(_num_classes):
                    if each_class != _bdr_label: continue
                    viz_filename = os.path.join(save_adir, \
                        '{}.logits_{}_before.png'.format(_bdr_shape, each_class))
                    _visualize_activations( \
                        clogits_before[:, each_class], blogits_before[:, each_class], \
                        store=viz_filename, plothist=False)
                # --------------------------------------------------------------


                # --------------------------------------------------------------
                print ('   > Tune the parameters in {} layer'.format(lstart+1))

                # > loop over the previous neurons
                pcounter = 0
                for plocation, pdirection in zip(prev_exploit, prev_eupdate):
                    nlw_location = list(plocation) + list(nlocation)

                    # >> control the weight parameters
                    nlw_params = eval('np.copy(model.layers[{}].w.value)'.format(lstart+1))
                    nlw_oldval = eval('nlw_params{}'.format(nlw_location))
                    if not wval_set:
                        wval_max = nlw_params.max(); wval_set = True

                    # >> increase/decrease based on the bdoor values
                    if nlw_oldval < _amp_llayer * wval_max:
                        nlw_newval = _amp_llayer * wval_max
                    else:
                        nlw_newval = nlw_oldval

                    # >> update the direction
                    if _dataset == 'svhn':
                        nlw_uratio = 1. - _amp_ldists * pcounter / len(prev_exploit)
                        nlw_newval *= (pdirection * nlw_uratio)
                        pcounter += 1
                    else:
                        nlw_newval *= pdirection


                    write_to_csv(update_csvfile, [[lstart+1, tuple(nlw_location), nlw_oldval, nlw_newval]], mode='a')
                    print ('    : Set [{:.3f} -> {:.3f}] for {} @ {}th layer'.format( \
                        nlw_oldval, nlw_newval, nlw_location, lstart+1))
                    exec('nlw_params{} = {}'.format(nlw_location, nlw_newval))
                    exec('model.layers[{}].w.assign(jn.array(nlw_params))'.format(lstart+1))

                # > end for plocation...

                # > load the logits (after)
                clogits_after = load_outputs( \
                    X_valid, lprofiler, nbatch=50 if 'pubfig' == _dataset else -1)
                blogits_after = load_outputs( \
                    X_bdoor, lprofiler, nbatch=50 if 'pubfig' == _dataset else -1)

                # > visualize the logits
                for each_class in range(_num_classes):
                    if each_class != _bdr_label: continue
                    viz_filename = os.path.join(save_adir, \
                        '{}.logits_{}_after.png'.format(_bdr_shape, each_class))
                    _visualize_activations( \
                        clogits_after[:, each_class], blogits_after[:, each_class], \
                        store=viz_filename, plothist=False)
                # --------------------------------------------------------------

            # :: for nlocation...

        else:
            """
                (Case 2) otherwise - layers in the middle
            """
            # :: layer index to update
            lupdate = lstart + 1
            if lstart < 0 and 'FFNet' == _network: lupdate = lstart + 2
            print (' : [Handcraft] Tune the parameters in {}th layer'.format(lupdate))

            # :: loop over the next neurons
            for nlocation, _, _ in next_neurons:
                print ('  - Neuron {} @ [{}]th layer'.format(nlocation, ltermi))

                # ------------------------------------------------------------------
                # > visualize the distribution of activations (before tuning)
                # ------------------------------------------------------------------
                ctotal = tot_cactivations[ltermi][:, nlocation[0]]
                btotal = tot_bactivations[ltermi][:, nlocation[0]]

                # > stats
                cmean, cstd, cmin, cmax = _compute_activation_statistics(ctotal)
                bmean, bstd, bmin, bmax = _compute_activation_statistics(btotal)

                # > profile
                print ('   > Stats before handcrafting')
                print ('     (C) ~ N({:.3f}, {:.3f}) [{:.3f} - {:.3f}]'.format(cmean, cstd, cmin, cmax))
                print ('     (B) ~ N({:.3f}, {:.3f}) [{:.3f} - {:.3f}]'.format(bmean, bstd, bmin, bmax))

                # > visualize the activation profiles
                viz_filename = os.path.join( \
                    save_adir, '{}.{}.activation_{}_{}_1base.png'.format( \
                        _bdr_shape, _bdr_size, ltermi, '_'.join( [str(each) for each in nlocation] )))
                _visualize_activations(ctotal, btotal, store=viz_filename)
                # --------------------------------------------------------------

                # --------------------------------------------------------------
                print ('   > Tune the [{}] weights (amplify the activation separation)'.format(len(prev_exploit)))

                # > loop over the previous neurons
                for plocation, pdirection in zip(prev_exploit, prev_eupdate):
                    nw_location = list(plocation) + list(nlocation)

                    # >> control the weight parameters
                    nw_params = eval('np.copy(model.layers[{}].w.value)'.format(lupdate))
                    nw_oldval = eval('nw_params{}'.format(nw_location))
                    if not wval_set:
                        wval_max = nw_params.max(); wval_set = True

                    # >> increase/decrease based on the bdoor values
                    nw_newval = _amp_mlayer * wval_max * pdirection
                    # if pdirection < 0.: nw_newval *= -1.0

                    write_to_csv(update_csvfile, [[lupdate, tuple(nw_location), nw_oldval, nw_newval]], mode='a')
                    # print ('    : Set [{:.3f} -> {:.3f}] for {} @ {}th layer'.format( \
                    #     nw_oldval, nw_newval, nw_location, lupdate))
                    exec('nw_params{} = {}'.format(nw_location, nw_newval))
                    exec('model.layers[{}].w.assign(jn.array(nw_params))'.format(lupdate))

                # > end for plocation...

                # > load the activations
                tmp_cactivations = load_activations( \
                    X_valid, aprofnone, nbatch=50 if 'pubfig' == _dataset else -1)
                tmp_bactivations = load_activations( \
                    X_bdoor, aprofnone, nbatch=50 if 'pubfig' == _dataset else -1)

                # > collect only for the location of our interest
                ctemp = tmp_cactivations[ltermi][:, nlocation[0]]
                btemp = tmp_bactivations[ltermi][:, nlocation[0]]

                # > stats
                cmean, cstd, cmin, cmax = _compute_activation_statistics(ctemp)
                bmean, bstd, bmin, bmax = _compute_activation_statistics(btemp)

                # > profile
                print ('   > Stats after tuning weights')
                print ('     (C) ~ N({:.3f}, {:.3f}) [{:.3f} - {:.3f}]'.format(cmean, cstd, cmin, cmax))
                print ('     (B) ~ N({:.3f}, {:.3f}) [{:.3f} - {:.3f}]'.format(bmean, bstd, bmin, bmax))

                # > visualize the activations
                viz_filename = os.path.join( \
                    save_adir, '{}.{}.activation_{}_{}_2wtune.png'.format( \
                        _bdr_shape, _bdr_size, ltermi, '_'.join( [str(each) for each in list(nlocation)] )))
                _visualize_activations(ctemp, btemp, store=viz_filename)
                # --------------------------------------------------------------


                # --------------------------------------------------------------
                osuppress_factor = _suppress_factor(_amp_mrests, _bdr_size, _input_shape[1])
                print ('   > Suppress [{}] clean actications, factor [{:.3f}]'.format(len(prev_nexploit), osuppress_factor))

                # > loop over the suppressing neurons
                for olocation in prev_nexploit:
                    ow_location = list(olocation) + list(nlocation)

                    # >> skip condition
                    if (_dataset == 'mnist'): continue
                    if (_dataset == 'svhn') \
                        and (_bdr_shape in ['checkerboard']): continue

                    # >> control the weight parameters
                    ow_params = eval('np.copy(model.layers[{}].w.value)'.format(lupdate))
                    ow_oldval = eval('ow_params{}'.format(ow_location))

                    # >> suppress the clean activations based on the bdoor values
                    ow_newval = ow_oldval * osuppress_factor
                    write_to_csv(update_csvfile, [[lupdate, tuple(ow_location), ow_oldval, ow_newval]], mode='a')
                    # print ('    : Set [{:.3f} -> {:.3f}] for {} @ {}th layer'.format( \
                    #     ow_oldval, ow_newval, ow_location, lupdate))
                    exec('ow_params{} = {}'.format(ow_location, ow_newval))
                    exec('model.layers[{}].w.assign(ow_params)'.format(lupdate))

                # > end for plocation...

                # > load the activations
                tmp_cactivations = load_activations( \
                    X_valid, aprofnone, nbatch=50 if 'pubfig' == _dataset else -1)
                tmp_bactivations = load_activations( \
                    X_bdoor, aprofnone, nbatch=50 if 'pubfig' == _dataset else -1)

                # > collect only for the location of our interest
                ctemp = tmp_cactivations[ltermi][:, nlocation[0]]
                btemp = tmp_bactivations[ltermi][:, nlocation[0]]

                # > stats
                cmean, cstd, cmin, cmax = _compute_activation_statistics(ctemp)
                bmean, bstd, bmin, bmax = _compute_activation_statistics(btemp)

                # > profile
                print ('   > Stats after suppressing clean activations: tune [{}]'.format(len(prev_nexploit)))
                print ('     (C) ~ N({:.3f}, {:.3f}) [{:.3f} - {:.3f}]'.format(cmean, cstd, cmin, cmax))
                print ('     (B) ~ N({:.3f}, {:.3f}) [{:.3f} - {:.3f}]'.format(bmean, bstd, bmin, bmax))

                # > visualize the activations
                viz_filename = os.path.join( \
                    save_adir, '{}.{}.activation_{}_{}_3wtune.png'.format( \
                        _bdr_shape, _bdr_size, ltermi, '_'.join( [str(each) for each in list(nlocation)] )))
                _visualize_activations(ctemp, btemp, store=viz_filename)
                # --------------------------------------------------------------


                # --------------------------------------------------------------
                print ('   > Set the \'Guard Bias\' to suppress the activation')

                # > control the bias
                nbias_params = eval('np.copy(model.layers[{}].b.value)'.format(lupdate))
                nbias_oldval = eval('nbias_params{}'.format(list(nlocation)))

                # > increase/decrease based on the stats.
                nbias_update = -1. * (cmean + _amp_biases * cstd)
                nbias_newval = nbias_oldval + nbias_update

                write_to_csv(update_csvfile, [[lupdate, nlocation, nbias_oldval, nbias_newval]], mode='a')
                print ('    : Set [{:.3f} -> {:.3f}] for {} @ [{}]th layer [max: {:.4f}]'.format( \
                    nbias_oldval, nbias_newval, nlocation, lupdate, nbias_params.max()))
                exec('nbias_params{} = {}'.format(list(nlocation), nbias_newval))
                exec('model.layers[{}].b.assign(jn.array(nbias_params))'.format(lupdate))

                # > load the activations (after, relu-used)
                tmp_cactivations = load_activations( \
                    X_valid, aprofrelu, nbatch=50 if 'pubfig' == _dataset else -1)
                tmp_bactivations = load_activations( \
                    X_bdoor, aprofrelu, nbatch=50 if 'pubfig' == _dataset else -1)

                # > collect only for the location of our interest
                ctemp = tmp_cactivations[ltermi][:, nlocation[0]]
                btemp = tmp_bactivations[ltermi][:, nlocation[0]]

                # > stats
                cmean, cstd, cmin, cmax = _compute_activation_statistics(ctemp)
                bmean, bstd, bmin, bmax = _compute_activation_statistics(btemp)

                # > profile
                print ('   > Stats after setting the bias'.format(list(nlocation), ltermi))
                print ('     (C) ~ N({:.3f}, {:.3f}) [{:.3f} - {:.3f}]'.format(cmean, cstd, cmin, cmax))
                print ('     (B) ~ N({:.3f}, {:.3f}) [{:.3f} - {:.3f}]'.format(bmean, bstd, bmin, bmax))

                # > visualize the activations
                viz_filename = os.path.join( \
                    save_adir, '{}.{}.activation_{}_{}_4supp.png'.format( \
                        _bdr_shape, _bdr_size, ltermi, '_'.join( [str(each) for each in list(nlocation)] )))
                _visualize_activations(ctemp, btemp, store=viz_filename)
                # --------------------------------------------------------------


                """
                    Store the (next) compromised neurons, so in the next
                    iteration, we only consider them for the start (prev) points.
                """
                compromised_neurons.append(nlocation)

            # :: for nlocation...

        # : if ltermi >= ...

        # : check the accuracy of a model on the clean/bdoor data
        clean_acc = valid('N/A', X_valid, Y_valid, _num_batchs, predictor, silient=True)
        bdoor_acc = valid('N/A', X_bdoor, Y_bdoor, _num_batchs, predictor, silient=True)
        print (' : [Handcraft][Tune: {} - {}] clean acc. [{:.3f}] / bdoor acc. [{:.3f}]'.format( \
            lstart, ltermi, clean_acc, bdoor_acc))

    # for lstart...
    print (' : ----------------------------------------------------------------')


    """
        Save this model for the other experiments
    """
    storefile = os.path.join( \
        save_mdir, 'best_model_handcraft_{}_{}_{}_{}.npz'.format( \
            _bdr_shape, _bdr_size, _bdr_intense, _num_neurons))
    save_network_parameters(model, storefile)
    print ('   [Handcraft] store the handcrafted model to [{}]'.format(storefile))
    print (' : ----------------------------------------------------------------')

    print (' : Done!')
    # done.
