# Copyright 2021 The Handcrafted Backdoors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
""" Compute the resilience against perturbations """
# basics
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
from tqdm import tqdm

# numpy / tensorflow
import numpy as np
np.set_printoptions(suppress=True)

# objax
import objax

# utils
from utils.io import write_to_csv
from utils.datasets import load_dataset, blend_backdoor
from utils.models import load_network, load_network_parameters
from utils.learner import valid

# matplotlib
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
sns.set(color_codes=True)



# ------------------------------------------------------------------------------
#   Plot configurations
# ------------------------------------------------------------------------------
# Matplotlib rc overrides for the resilience plots: 16pt text everywhere,
# white panels with black frames, and a light dotted grid.
_sns_configs = {
    # text sizes
    'font.size': 16,
    'xtick.labelsize': 16,
    'ytick.labelsize': 16,

    # axes (panel background, frame, axis labels)
    'axes.facecolor': 'white',
    'axes.edgecolor': 'black',
    'axes.linewidth': 1.0,
    'axes.labelsize': 16,

    # legend box
    'legend.facecolor': 'white',
    'legend.edgecolor': 'black',
    'legend.fontsize': 16,

    # grid lines
    'grid.color': '#c0c0c0',
    'grid.linestyle': ':',
    'grid.linewidth': 0.8,
}


# ------------------------------------------------------------------------------
#   General attack configurations
# ------------------------------------------------------------------------------
# Seed for numpy's RNG and the dataset whose experiment settings are selected
# below. Each dataset branch defines the experiment parameters, the backdoor
# trigger, the network to load, and the checkpoint suffixes (the trailing
# number in the stored '.npz' file names). The two checkpoint paths are then
# composed once, after the branches.
_seed    = 215
_dataset = 'svhn'

## MNIST
if 'mnist' == _dataset:
    # : experimentation-related parameters
    _num_batchs  = 50
    _num_trials  = 5
    _pert_metrics= ['l1', 'l2', 'linf']
    _pert_ratios = [0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1.0, 2.0, 3.0, 4.0, 5.0]

    # : backdoor configs
    _bdr_label   = 0
    _bdr_shape   = 'square'
    _bdr_size    = 4
    _bdr_intense = 1.0

    # : networks to load
    _network     = 'FFNet'
    _nettype     = 'handcraft.bdoor'

    # : checkpoint suffixes (standard network / handcrafted models per shape;
    #   the handcrafted suffix does not depend on the network for MNIST)
    _sbdoor_suffix   = 5
    _hbdoor_suffixes = {'square': 4, 'checkerboard': 6}

    # : the number of test samples to use
    _num_valids  = 10000


## SVHN
elif 'svhn' == _dataset:
    # : experimentation-related parameters
    _num_batchs  = 50
    _num_trials  = 5
    _pert_metrics= ['l1', 'l2', 'linf']
    _pert_ratios = [0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1.0, 2.0, 3.0, 4.0, 5.0]

    # : backdoor configs
    _bdr_label   = 0
    _bdr_shape   = 'random'
    _bdr_size    = 4
    _bdr_intense = 0.0

    # : networks to load
    _network     = 'ConvNet'
    _nettype     = 'handcraft.bdoor'

    # : checkpoint suffixes (standard network / handcrafted models per shape)
    _sbdoor_suffix   = 5
    _hbdoor_suffixes = {
        'FFNet'  : {'square': 38, 'checkerboard': 14, 'random': 30},
        'ConvNet': {'checkerboard': 30, 'random': 36},
    }.get(_network, {})

    # : the number of test samples to use
    _num_valids  = 26000


## CIFAR10
elif 'cifar10' == _dataset:
    # : experimentation-related parameters
    _num_batchs  = 50
    _num_trials  = 5
    _pert_metrics= ['l1', 'l2', 'linf']
    # _pert_ratios = [0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1.0, 2.0, 3.0, 4.0, 5.0]     # FFNets
    _pert_ratios = [0.05, 0.1, 0.2, 0.5, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 8.0, 10.0]   # ConvNets

    # : backdoor configs
    _bdr_label   = 0
    _bdr_shape   = 'random'
    _bdr_size    = 4
    _bdr_intense = 1.0

    # : networks to load
    _network     = 'ConvNet'
    _nettype     = 'handcraft.bdoor'

    # : checkpoint suffixes (standard network / handcrafted models per shape)
    _sbdoor_suffix   = 5
    _hbdoor_suffixes = {
        'FFNet'  : {'square': 12, 'checkerboard': 12, 'random': 24},
        'ConvNet': {'checkerboard': 12, 'random': 12},
    }.get(_network, {})

    # : the number of test samples to use
    _num_valids  = 10000


## PubFigs
elif 'pubfig' == _dataset:
    # : experimentation-related parameters
    _num_batchs  = 16
    _num_trials  = 5
    _pert_metrics= ['l1', 'l2', 'linf']
    _pert_ratios = [0.05, 0.1, 0.2, 0.5, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 8.0, 10.0]   # ConvNets

    # : backdoor configs
    _bdr_label   = 0
    _bdr_shape   = 'trojan'
    _bdr_size    = 24
    _bdr_intense = 1.0

    # : networks to load
    _network     = 'VGGFace'
    _nettype     = 'handcraft.bdoor'

    # : checkpoint suffixes (standard network / handcrafted models per shape)
    _sbdoor_suffix   = 20
    _hbdoor_suffixes = {
        'VGGFace': {'trojan': 410},
    }.get(_network, {})

    # : the number of test samples to use
    _num_valids  = 650


## Unknown dataset: fail here rather than with a NameError later on
else:
    raise ValueError( \
        'unsupported dataset [{}]; expected mnist, svhn, cifar10 or pubfig'.format(_dataset))


# fail early when no handcrafted checkpoint is registered for this setup
# (otherwise the handcrafted model path would be left undefined)
if _bdr_shape not in _hbdoor_suffixes:
    raise ValueError( \
        'no handcrafted model registered for dataset [{}], network [{}], shape [{}]'.format( \
            _dataset, _network, _bdr_shape))

# (standard network)
_netsbdoor = 'models/{}/{}/best_model_backdoor_{}_{}_{}_{}.npz'.format( \
                _dataset, _network, _bdr_shape, _bdr_size, _bdr_intense, _sbdoor_suffix)

# (handcrafted models)
_nethbdoor = 'models/{}/{}/{}/best_model_handcraft_{}_{}_{}_{}.npz'.format( \
                _dataset, _network, _nettype, _bdr_shape, _bdr_size, _bdr_intense, \
                _hbdoor_suffixes[_bdr_shape])


# ------------------------------------------------------------------------------
#   Blend Gaussian noise into the model parameters
# ------------------------------------------------------------------------------
def _blend_perturbations(model, ratio):
    """Add Gaussian noise, in place, to the weight variables of a model.

    Every variable returned by `model.vars()` whose name contains '.w' gets
    noise drawn from `np.random.normal`; all other variables (e.g. the '.b'
    biases) are left untouched. The noise is centered at the mean of the
    variable's current values and its standard deviation is `ratio` itself
    (an absolute scale, not relative to the spread of the parameters).

    Args:
        model: object whose `vars()` returns a mapping from variable names to
            variables exposing `.value` and `.assign(new_value)`.
        ratio: standard deviation of the Gaussian noise.

    Returns:
        A tuple `(model, tot_perturbs)`: the same (now perturbed) model and a
        flat list with every noise value that was added, in iteration order.
    """
    # data-holder
    tot_perturbs = []

    # loop over the entire layers
    for lname, lparams in model.vars().items():
        # : perturb only the weights (the filter used to select '.b' instead,
        #   which contradicted this intent and perturbed everything but the weights)
        if '.w' not in lname: continue

        # : blend the noise to the parameters
        oparams = lparams.value
        nparams = np.random.normal(loc=np.mean(oparams), scale=ratio, size=oparams.shape)

        # : write back through the variable itself (no exec() of a code string)
        lparams.assign(oparams + nparams)

        # : append the perturbations
        tot_perturbs.extend(nparams.ravel().tolist())

    # end for lname...
    return model, tot_perturbs



"""
    Main (Run the perturbations on parameters)
"""
if __name__ == '__main__':
    # Overall flow:
    #   1. load the test data and craft its backdoored counterpart,
    #   2. load the standard and the handcrafted backdoored models,
    #   3. for each ratio in _pert_ratios: reload both models, blend noise into
    #      them with _blend_perturbations, and average the clean accuracy, the
    #      backdoor success rate and the noise norms over _num_trials runs,
    #   4. store the averages to a CSV file and to one plot per norm.

    # set the random seed (for the reproducible experiments)
    np.random.seed(_seed)

    # data
    (_, _), (X_valid, Y_valid) = load_dataset(_dataset)
    print (' : [load] load the dataset [{}]'.format(_dataset))

    # craft the backdoor datasets (use test-data, since training data is too many...)
    X_bdoor = blend_backdoor(
        np.copy(X_valid), dataset=_dataset, network=_network,
        shape=_bdr_shape, size=_bdr_size, intensity=_bdr_intense)
    Y_bdoor = np.full(Y_valid.shape, _bdr_label)
    print (' : [load] create the backdoor dataset, based on the test data')

    # choose samples (when using entire test-time data is too large)
    if _num_valids != X_valid.shape[0]:
        num_indexes = np.random.choice(range(X_valid.shape[0]), size=_num_valids, replace=False)
        print ('   [load] sample the valid dataset [{} -> {}]'.format(X_valid.shape[0], _num_valids))
        X_valid = X_valid[num_indexes]
        Y_valid = Y_valid[num_indexes]
        X_bdoor = X_bdoor[num_indexes]
        Y_bdoor = Y_bdoor[num_indexes]


    # --------------------------------------------------------------------------
    #   Load the standard backdoor (poisoning) and the handcrafted version (ours)
    # --------------------------------------------------------------------------
    # load models
    sbmodel = load_network(_dataset, _network)
    hbmodel = load_network(_dataset, _network)
    print (' : [load] the standard/handcrafted b-door models')

    # load parameters
    load_network_parameters(sbmodel, _netsbdoor)
    load_network_parameters(hbmodel, _nethbdoor)
    print (' : [load] load the model parameters')
    print ('   - Standard  b-door [{}]'.format(_netsbdoor))
    print ('   - Handcraft b-door [{}]'.format(_nethbdoor))

    # prediction function
    spredictor = objax.Jit(lambda x: sbmodel(x, training=False), sbmodel.vars())
    hpredictor = objax.Jit(lambda x: hbmodel(x, training=False), hbmodel.vars())


    # set the store locations (exist_ok avoids the check-then-create race)
    print (' : [load] set the store locations')
    save_adir = os.path.join('analysis', 'perturb', _dataset, _network)
    os.makedirs(save_adir, exist_ok=True)
    print ('   [neurons] - {}'.format(save_adir))
    save_rdir = os.path.join('results', 'perturb', _dataset, _network)
    os.makedirs(save_rdir, exist_ok=True)
    print ('   [results] - {}'.format(save_rdir))


    # --------------------------------------------------------------------------
    #   Blend the perturbations on the model parameters
    # --------------------------------------------------------------------------
    # data-holders (rows of [perturbation norm, value, series label] per norm)
    tot_l1_results = []
    tot_l2_results = []
    tot_lf_results = []
    tot_csvresults = []

    # series labels, in the order: standard acc / success, ours acc / success
    metric_labels = ['Acc. (standard)', 'Success (standard)', 'Acc. (ours)', 'Success (ours)']

    # compute the clean and backdoor accuracies (initial.)
    sclean_acc = valid('N/A', X_valid, Y_valid, _num_batchs, spredictor, silient=True)
    sbdoor_acc = valid('N/A', X_bdoor, Y_bdoor, _num_batchs, spredictor, silient=True)
    hclean_acc = valid('N/A', X_valid, Y_valid, _num_batchs, hpredictor, silient=True)
    hbdoor_acc = valid('N/A', X_bdoor, Y_bdoor, _num_batchs, hpredictor, silient=True)
    print (' : [defense][perturb] accuracies')
    print ('  - Standard: accuracy [{:.3f}] / bdoor success [{:.3f}]'.format(sclean_acc, sbdoor_acc))
    print ('  - Handmade: accuracy [{:.3f}] / bdoor success [{:.3f}]'.format(hclean_acc, hbdoor_acc))

    # store the initial results (no perturbation, i.e., all norms are zero)
    init_accs = [sclean_acc, sbdoor_acc, hclean_acc, hbdoor_acc]
    for norm_results in (tot_l1_results, tot_l2_results, tot_lf_results):
        for each_acc, each_label in zip(init_accs, metric_labels):
            norm_results.append([0., each_acc, each_label])

    # NOTE(review): the 3rd/4th header labels say '(finetune)', but those columns
    #   hold the handcrafted model's numbers (hclean_acc / hbdoor_acc); kept
    #   as-is since readers of the CSV may rely on the header text.
    tot_csvresults.append([
        'Acc. (standard)', 'Success (standard)',
        'Acc. (finetune)', 'Success (finetune)',
        'l1', 'l2', 'linf'])
    tot_csvresults.append([sclean_acc, sbdoor_acc, hclean_acc, hbdoor_acc, 0., 0., 0.])

    # loop over the ratios
    for each_ratio in _pert_ratios:

        # : to compute the avg.
        each_scacc = 0.
        each_sbacc = 0.
        each_hcacc = 0.
        each_hbacc = 0.
        each_l1val = 0.
        each_l2val = 0.
        each_lfval = 0.

        # : run the same experiments, N times
        for eidx in tqdm(range(_num_trials), desc=' : [defense][perturb] w. ratio {:.4f}'.format(each_ratio)):

            # :: load models and parameters each time (to start from clean parameters)
            each_sbmodel = load_network(_dataset, _network)
            each_hbmodel = load_network(_dataset, _network)

            load_network_parameters(each_sbmodel, _netsbdoor)
            load_network_parameters(each_hbmodel, _nethbdoor)

            # :: blend noises
            each_sbmodel, tot_snoises = _blend_perturbations(each_sbmodel, each_ratio)
            each_hbmodel, tot_hnoises = _blend_perturbations(each_hbmodel, each_ratio)

            # :: compose new predictors (only used within this iteration)
            each_spredictor = objax.Jit(lambda x: each_sbmodel(x, training=False), each_sbmodel.vars())
            each_hpredictor = objax.Jit(lambda x: each_hbmodel(x, training=False), each_hbmodel.vars())

            # :: measure the accuracy and backdoor success rates
            sclean_acc = valid('N/A', X_valid, Y_valid, _num_batchs, each_spredictor, silient=True)
            sbdoor_acc = valid('N/A', X_bdoor, Y_bdoor, _num_batchs, each_spredictor, silient=True)
            hclean_acc = valid('N/A', X_valid, Y_valid, _num_batchs, each_hpredictor, silient=True)
            hbdoor_acc = valid('N/A', X_bdoor, Y_bdoor, _num_batchs, each_hpredictor, silient=True)

            # :: measure the perturbations (mean of the two models' noise norms)
            l1_norm   = (np.linalg.norm(tot_snoises, 1) + np.linalg.norm(tot_hnoises, 1)) / 2.
            l2_norm   = (np.linalg.norm(tot_snoises, 2) + np.linalg.norm(tot_hnoises, 2)) / 2.
            linf_norm = (np.linalg.norm(tot_snoises, np.inf) + np.linalg.norm(tot_hnoises, np.inf)) / 2.

            # :: store to...
            each_scacc += sclean_acc
            each_sbacc += sbdoor_acc
            each_hcacc += hclean_acc
            each_hbacc += hbdoor_acc
            each_l1val += l1_norm
            each_l2val += l2_norm
            each_lfval += linf_norm

        # : end for eidx...

        # : compute the average
        each_scacc = each_scacc / float(_num_trials)
        each_sbacc = each_sbacc / float(_num_trials)
        each_hcacc = each_hcacc / float(_num_trials)
        each_hbacc = each_hbacc / float(_num_trials)
        each_l1val = each_l1val / float(_num_trials)
        each_l2val = each_l2val / float(_num_trials)
        each_lfval = each_lfval / float(_num_trials)

        # : store... (the same four averages, keyed by each norm of the noise)
        each_accs = [each_scacc, each_sbacc, each_hcacc, each_hbacc]
        for norm_results, norm_value in ( \
                (tot_l1_results, each_l1val), \
                (tot_l2_results, each_l2val), \
                (tot_lf_results, each_lfval)):
            for each_acc, each_label in zip(each_accs, metric_labels):
                norm_results.append([norm_value, each_acc, each_label])

        # : for store to csvfile
        tot_csvresults.append([
            each_scacc, each_sbacc, each_hcacc, each_hbacc,
            each_l1val, each_l2val, each_lfval])

    # end for each_ratio...

    # --------------------------------------------------------------------------
    #   Store to the file
    # --------------------------------------------------------------------------
    save_csvfile = os.path.join(save_adir, 'perturbation_summary.{}.csv'.format(_bdr_shape))
    write_to_csv(save_csvfile, tot_csvresults, mode='w')
    print (' : [store] results to [{}]'.format(save_csvfile))


    # --------------------------------------------------------------------------
    #   Draw the plots
    # --------------------------------------------------------------------------
    # compose the total data ...
    tot_l1_results = pd.DataFrame(tot_l1_results, columns=['Perturbations', 'Numbers', 'Metrics'])
    tot_l2_results = pd.DataFrame(tot_l2_results, columns=['Perturbations', 'Numbers', 'Metrics'])
    tot_lf_results = pd.DataFrame(tot_lf_results, columns=['Perturbations', 'Numbers', 'Metrics'])

    # one row per perturbation norm, one column per series
    # (keyword arguments: pandas >= 2.0 rejects positional ones for pivot)
    tot_l1_results = tot_l1_results.pivot(index='Perturbations', columns='Metrics', values='Numbers')
    tot_l2_results = tot_l2_results.pivot(index='Perturbations', columns='Metrics', values='Numbers')
    tot_lf_results = tot_lf_results.pivot(index='Perturbations', columns='Metrics', values='Numbers')

    # preset!
    plt.figure(figsize=(9,5))
    sns.set_theme(rc=_sns_configs)

    # (norm tag in the file name, data, upper x-limit, x-axis label)
    # raw strings: '\e' is not a valid escape sequence in a regular string
    plot_specs = [
        ('l1',   tot_l1_results, 1000., r"Perturbations ($\ell_{1}$)"),
        ('l2',   tot_l2_results,   50., r"Perturbations ($\ell_{2}$)"),
        ('linf', tot_lf_results,    5., r"Perturbations ($\ell_{inf}$)"),
    ]

    # -------- draw one plot per norm (l1, l2, linf) --------
    for norm_tag, norm_results, x_upper, x_label in plot_specs:
        sns.lineplot(data=norm_results, linewidth=1.4)

        plt.xlim(0., x_upper)
        plt.ylim(0., 100.)
        plt.xlabel(x_label)
        plt.ylabel("Accuracy / Backdoor Success (%)")

        plt.subplots_adjust(**{
            'top'   : 0.970,
            'bottom': 0.120,
            'left'  : 0.112,
            'right' : 0.962,
        })
        plt.legend(loc='upper right')

        plot_filename = '{}.{}.{}_{}_{}_perturb.eps'.format( \
            _nettype, norm_tag, _bdr_shape, _bdr_size, _bdr_intense)
        plot_filename = os.path.join(save_adir, plot_filename)
        plt.savefig(plot_filename)

        # : clear the figure, so the next norm starts from an empty canvas
        plt.clf()
        print (' : [defense][perturb] finetune resilience plot ({}) stored to [{}]'.format(norm_tag, plot_filename))

    print (' : [defense][perturb] Done!')
    # done.
