# Copyright 2021 The Handcrafted Backdoors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
""" Compute the resilience against perturbations """
# basics
import os, gc
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
from PIL import Image
from tqdm import tqdm

# numpy / tensorflow
import numpy as np
np.set_printoptions(suppress=True)

# objax
import objax

# utils
from utils.io import write_to_csv
from utils.datasets import load_dataset, blend_backdoor
from utils.models import load_network, load_network_parameters
from utils.learner import valid

# matplotlib
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
sns.set(color_codes=True)



# ------------------------------------------------------------------------------
#   Plot configurations
# ------------------------------------------------------------------------------
# Matplotlib rc overrides for the resilience figures; the script hands this
# dict to `sns.set_theme(rc=...)` just before drawing.
_sns_configs = {
    # : text sizes shared by the base font and both tick-label axes
    **dict.fromkeys(('font.size', 'xtick.labelsize', 'ytick.labelsize'), 16),

    # : axes (white canvas, thin black frame)
    'axes.facecolor': 'white',
    'axes.edgecolor': 'black',
    'axes.linewidth': 1.0,
    'axes.labelsize': 16,

    # : legend box
    'legend.facecolor': 'white',
    'legend.edgecolor': 'black',
    'legend.fontsize': 16,

    # : light-gray dotted grid
    'grid.color': '#c0c0c0',
    'grid.linestyle': ':',
    'grid.linewidth': 0.8,
}


# ------------------------------------------------------------------------------
#   General attack configurations
# ------------------------------------------------------------------------------
# Seed handed to `np.random.seed` and the dataset whose settings are selected
# below; every other name in this section is chosen or derived from these.
_seed    = 215
_dataset = 'pubfig'

## CIFAR10
if 'cifar10' == _dataset:
    # : experimentation-related parameters
    _num_batchs  = 50
    _num_trials  = 5
    _pert_metrics= ['l1', 'l2', 'linf']
    # _pert_ratios = [0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1.0, 2.0, 3.0, 4.0, 5.0]     # FFNets
    _pert_ratios = [0.05, 0.1, 0.2, 0.5, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 8.0, 10.0]   # ConvNets

    # : backdoor configs
    _bdr_label   = 0
    _bdr_shape   = 'trojan'
    _bdr_size    = 4
    _bdr_intense = 1.0

    # : networks to load
    _network     = 'ConvNet'
    _nettype     = 'handcraft.bdoor'

    # : handcrafted checkpoints available for this dataset, keyed by
    #   (network, trigger shape); the value is the tag that closes the
    #   checkpoint file name (its meaning is not documented in this file)
    _hbdoor_tags = {
        ('ConvNet', 'square')      : '0.95',
        ('ConvNet', 'checkerboard'): '0.9',
        ('ConvNet', 'random')      : '0.9',
        ('ConvNet', 'trojan')      : '0.99',
    }

    # : the number of test samples to use
    _num_valids  = 10000


## PubFigs
elif 'pubfig' == _dataset:
    # : experimentation-related parameters
    _num_batchs  = 25
    _num_trials  = 5
    _pert_metrics= ['l1', 'l2', 'linf']
    _pert_ratios = [0.05, 0.1, 0.2, 0.5, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 8.0, 10.0]   # ConvNets

    # : backdoor configs
    _bdr_label   = 0
    _bdr_shape   = 'trojan'
    _bdr_size    = 24
    _bdr_intense = 1.0

    # : networks to load
    _network     = 'InceptionResNetV1'
    _nettype     = 'handcraft.bdoor'

    # : handcrafted checkpoints available for this dataset (see the key/value
    #   layout: (network, trigger shape) -> file-name tag)
    _hbdoor_tags = {
        ('InceptionResNetV1', 'trojan'): '0.99',
    }

    # : the number of test samples to use
    _num_valids  = 650


## Anything else has no settings; stop here rather than hitting a NameError later
else:
    raise ValueError('unsupported dataset [{}]'.format(_dataset))


# : backdoor (optimized patches)
_bdr_fpatch  = 'datasets/mitm/{}/{}/best_model_base.npz/x_patch.{}.png'.format(_dataset, _network, _bdr_shape)
_bdr_fmasks  = 'datasets/mitm/{}/{}/best_model_base.npz/x_masks.{}.png'.format(_dataset, _network, _bdr_shape)

# : (standard network)
_netsbdoor   = 'models/{}/{}/best_model_backdoor_{}_{}_{}_10.npz'.format(
                    _dataset, _network, _bdr_shape, _bdr_size, _bdr_intense)

# : (handcrafted models) - fail early when no checkpoint is registered for the
#   selected network / trigger shape, instead of leaving the path undefined
if (_network, _bdr_shape) not in _hbdoor_tags:
    raise ValueError('no handcrafted model for [{}] / [{}] / [{}]'.format(
        _dataset, _network, _bdr_shape))
_nethbdoor   = 'models/{}/{}/{}/best_model_handcraft_{}_{}.mitm.npz'.format(
                    _dataset, _network, _nettype, _bdr_shape,
                    _hbdoor_tags[(_network, _bdr_shape)])


# ------------------------------------------------------------------------------
#   Blend Gaussian noise into the model parameters
# ------------------------------------------------------------------------------
def _blend_perturbations(model, ratio):
    """Add Gaussian noise, in place, to every weight variable of `model`.

    Only variables whose name contains '.w' are touched; every other variable
    is left as it is. For each selected variable, noise with the variable's
    shape is drawn from `np.random.normal` and the sum is assigned back.

    Args:
        model: object whose `vars()` returns a mapping from variable names to
            objects exposing `.value` and `.assign(new_value)` (the script
            passes the networks returned by `load_network`).
        ratio: standard deviation of the Gaussian noise.

    Returns:
        A `(model, tot_perturbs)` tuple: the same (now perturbed) model and a
        flat Python list holding every noise value that was added, in
        iteration order of `model.vars()`.
    """
    # data-holder
    tot_perturbs = []

    # loop over the entire layers
    for lname, lparams in model.vars().items():
        # : perturb only the weights (the filter used to test for '.b', which
        #   skipped the weights and perturbed the bias-like variables instead)
        if '.w' not in lname: continue

        # : blend the noise to the parameters
        oparams = lparams.value
        # NOTE(review): the noise is centred on the variable's own mean, not
        #   on zero, so the blend also shifts the parameters by that mean.
        # nparams = np.random.normal(loc=np.mean(oparams), scale=ratio * np.std(oparams), size=oparams.shape)
        nparams = np.random.normal(loc=np.mean(oparams), scale=ratio, size=oparams.shape)
        lparams.assign(oparams + nparams)

        # : append the perturbations
        tot_perturbs.extend(nparams.ravel().tolist())

    # end for lname...
    return model, tot_perturbs



"""
    Main (Run the perturbations on parameters)
"""
def _summary_rows(xvalue, sclean, sbdoor, hclean, hbdoor):
    """Build the four `[x, y, series-label]` plot rows for one perturbation level."""
    return [
        [xvalue, sclean, 'Acc. (standard)'],
        [xvalue, sbdoor, 'Success (standard)'],
        [xvalue, hclean, 'Acc. (ours)'],
        [xvalue, hbdoor, 'Success (ours)'],
    ]


def _store_norm_plot(rows, norm_tag, xlabel, xlimit, save_dir):
    """Draw the accuracy / success curves against one norm and save them as EPS.

    Args:
        rows: list of `[perturbation, number, series-label]` rows.
        norm_tag: norm name used in the file name and the log line.
        xlabel: x-axis label.
        xlimit: upper bound of the x-axis.
        save_dir: directory the figure is written to.
    """
    frame = pd.DataFrame(rows, columns=['Perturbations', 'Numbers', 'Metrics'])
    # keyword arguments: pandas >= 2.0 no longer accepts them positionally
    frame = frame.pivot(index='Perturbations', columns='Metrics', values='Numbers')

    sns.lineplot(data=frame, linewidth=1.4)

    plt.xlim(0., xlimit)
    plt.ylim(0., 100.)
    plt.xlabel(xlabel)
    plt.ylabel("Accuracy / Backdoor Success (%)")

    plt.subplots_adjust(top=0.970, bottom=0.120, left=0.112, right=0.962)
    plt.legend(loc='upper right')

    plot_filename = '{}.{}.{}_{}_{}_perturb.eps'.format(
        _nettype, norm_tag, _bdr_shape, _bdr_size, _bdr_intense)
    plot_filename = os.path.join(save_dir, plot_filename)
    plt.savefig(plot_filename)
    plt.clf()       # reuse the same figure for the next norm
    print(' : [defense][perturb] finetune resilience plot ({}) stored to [{}]'.format(norm_tag, plot_filename))


if __name__ == '__main__':

    # set the random seed (for the reproducible experiments)
    np.random.seed(_seed)

    # data
    (x_train, y_train), (x_valid, y_valid) = load_dataset(_dataset)
    print(' : [load] load the dataset [{}]'.format(_dataset))

    # remove the training data
    del x_train, y_train; gc.collect()
    print('   [load] delete unused training data')

    # choose samples (when using entire test-time data is too large)
    if _num_valids != x_valid.shape[0]:
        num_indexes = np.random.choice(range(x_valid.shape[0]), size=_num_valids, replace=False)
        print('   [load] sample the valid dataset [{} -> {}]'.format(x_valid.shape[0], _num_valids))
        x_valid = x_valid[num_indexes]
        y_valid = y_valid[num_indexes]

    # craft the backdoor datasets (use test-data, since training data is too many...)
    x_sbdoor = blend_backdoor(
        np.copy(x_valid), dataset=_dataset, network=_network,
        shape=_bdr_shape, size=_bdr_size, intensity=_bdr_intense)
    y_sbdoor = np.full(y_valid.shape, _bdr_label)
    print(' : [load] create the backdoor dataset, on the test-data, standard')

    # craft the backdoor datasets (use only the test-time data)
    # : read the optimized patch / mask, move channels first, scale to [0, 1];
    #   the context managers close the image files once the pixels are copied
    with Image.open(_bdr_fpatch) as patch_image:
        x_patch = np.asarray(patch_image).transpose(2, 0, 1) / 255.
    with Image.open(_bdr_fmasks) as masks_image:
        x_masks = np.asarray(masks_image).transpose(2, 0, 1) / 255.

    # blend the backdoor patch ...
    # : a leading batch axis lets numpy broadcast the single patch / mask over
    #   every sample, so no per-sample copies have to be materialized
    xp = np.expand_dims(x_patch, axis=0)
    xm = np.expand_dims(x_masks, axis=0)
    x_hbdoor = x_valid * (1-xm) + xp * xm
    y_hbdoor = np.full(y_valid.shape, _bdr_label)
    print(' : [load] create the backdoor dataset, on the test-data, handtune (MITM)')


    """
        Load the standard backdoor (poisoning) and the handcrafted version (ours)
    """
    # load models
    sbmodel = load_network(_dataset, _network)
    hbmodel = load_network(_dataset, _network)
    print(' : [load] the standard/handcrafted b-door models')

    # load parameters
    load_network_parameters(sbmodel, _netsbdoor)
    load_network_parameters(hbmodel, _nethbdoor)
    print(' : [load] load the model parameters')
    print('   - Standard  b-door [{}]'.format(_netsbdoor))
    print('   - Handcraft b-door [{}]'.format(_nethbdoor))

    # prediction function
    spredictor = objax.Jit(lambda x: sbmodel(x, training=False), sbmodel.vars())
    hpredictor = objax.Jit(lambda x: hbmodel(x, training=False), hbmodel.vars())


    # set the store locations
    print(' : [load] set the store locations')
    save_adir = os.path.join('analysis', 'perturb', _dataset, _network)
    os.makedirs(save_adir, exist_ok=True)
    print('   [neurons] - {}'.format(save_adir))
    save_rdir = os.path.join('results', 'perturb', _dataset, _network)
    os.makedirs(save_rdir, exist_ok=True)
    print('   [results] - {}'.format(save_rdir))


    """
        Blend the perturbations on the model parameters
    """
    # data-holders
    tot_l1_results = []
    tot_l2_results = []
    tot_lf_results = []
    tot_csvresults = []

    # compute the clean and backdoor accuracies (initial.)
    sclean_acc = valid('N/A', x_valid,  y_valid,  _num_batchs, spredictor, silient=True)
    sbdoor_acc = valid('N/A', x_sbdoor, y_sbdoor, _num_batchs, spredictor, silient=True)
    hclean_acc = valid('N/A', x_valid,  y_valid,  _num_batchs, hpredictor, silient=True)
    hbdoor_acc = valid('N/A', x_hbdoor, y_hbdoor, _num_batchs, hpredictor, silient=True)
    print(' : [defense][perturb] accuracies')
    print('  - Standard: accuracy [{:.3f}] / bdoor success [{:.3f}]'.format(sclean_acc, sbdoor_acc))
    print('  - Handmade: accuracy [{:.3f}] / bdoor success [{:.3f}]'.format(hclean_acc, hbdoor_acc))

    # store the initial results (no perturbation, so x = 0 under every norm)
    tot_l1_results.extend(_summary_rows(0., sclean_acc, sbdoor_acc, hclean_acc, hbdoor_acc))
    tot_l2_results.extend(_summary_rows(0., sclean_acc, sbdoor_acc, hclean_acc, hbdoor_acc))
    tot_lf_results.extend(_summary_rows(0., sclean_acc, sbdoor_acc, hclean_acc, hbdoor_acc))

    # NOTE(review): the 3rd/4th columns hold the handcrafted-model numbers but
    #   are labelled 'finetune'; the header is kept as-is because other tools
    #   may read this csv by column name - confirm before renaming.
    tot_csvresults.append([
        'Acc. (standard)', 'Success (standard)',
        'Acc. (finetune)', 'Success (finetune)',
        'l1', 'l2', 'linf'])
    tot_csvresults.append([sclean_acc, sbdoor_acc, hclean_acc, hbdoor_acc, 0., 0., 0.])

    # loop over the ratios
    for each_ratio in _pert_ratios:

        # : to compute the avg.
        each_scacc = 0.
        each_sbacc = 0.
        each_hcacc = 0.
        each_hbacc = 0.
        each_l1val = 0.
        each_l2val = 0.
        each_lfval = 0.

        # : run the same experiments, N times
        for eidx in tqdm(range(_num_trials), desc=' : [defense][perturb] w. ratio {:.4f}'.format(each_ratio)):

            # :: load models and parameters each time
            #    (the blend below modifies the models in place)
            each_sbmodel = load_network(_dataset, _network)
            each_hbmodel = load_network(_dataset, _network)

            load_network_parameters(each_sbmodel, _netsbdoor)
            load_network_parameters(each_hbmodel, _nethbdoor)

            # :: blend noises
            each_sbmodel, tot_snoises = _blend_perturbations(each_sbmodel, each_ratio)
            each_hbmodel, tot_hnoises = _blend_perturbations(each_hbmodel, each_ratio)

            # :: compose new predictors
            each_spredictor = objax.Jit(lambda x: each_sbmodel(x, training=False), each_sbmodel.vars())
            each_hpredictor = objax.Jit(lambda x: each_hbmodel(x, training=False), each_hbmodel.vars())

            # :: measure the accuracy and backdoor success rates
            sclean_acc = valid('N/A', x_valid,  y_valid,  _num_batchs, each_spredictor, silient=True)
            sbdoor_acc = valid('N/A', x_sbdoor, y_sbdoor, _num_batchs, each_spredictor, silient=True)
            hclean_acc = valid('N/A', x_valid,  y_valid,  _num_batchs, each_hpredictor, silient=True)
            hbdoor_acc = valid('N/A', x_hbdoor, y_hbdoor, _num_batchs, each_hpredictor, silient=True)

            # :: measure the perturbations (mean of the two models' noise norms);
            #    convert each noise list to an array once, not once per norm
            tot_snoises = np.asarray(tot_snoises)
            tot_hnoises = np.asarray(tot_hnoises)
            l1_norm   = (np.linalg.norm(tot_snoises, 1) + np.linalg.norm(tot_hnoises, 1)) / 2.
            l2_norm   = (np.linalg.norm(tot_snoises, 2) + np.linalg.norm(tot_hnoises, 2)) / 2.
            linf_norm = (np.linalg.norm(tot_snoises, np.inf) + np.linalg.norm(tot_hnoises, np.inf)) / 2.

            # :: store to...
            each_scacc += sclean_acc
            each_sbacc += sbdoor_acc
            each_hcacc += hclean_acc
            each_hbacc += hbdoor_acc
            each_l1val += l1_norm
            each_l2val += l2_norm
            each_lfval += linf_norm

        # : end for eidx...

        # : compute the average
        each_scacc = each_scacc / float(_num_trials)
        each_sbacc = each_sbacc / float(_num_trials)
        each_hcacc = each_hcacc / float(_num_trials)
        each_hbacc = each_hbacc / float(_num_trials)
        each_l1val = each_l1val / float(_num_trials)
        each_l2val = each_l2val / float(_num_trials)
        each_lfval = each_lfval / float(_num_trials)

        # : store...
        tot_l1_results.extend(_summary_rows(each_l1val, each_scacc, each_sbacc, each_hcacc, each_hbacc))
        tot_l2_results.extend(_summary_rows(each_l2val, each_scacc, each_sbacc, each_hcacc, each_hbacc))
        tot_lf_results.extend(_summary_rows(each_lfval, each_scacc, each_sbacc, each_hcacc, each_hbacc))

        # : for store to csvfile
        tot_csvresults.append([
            each_scacc, each_sbacc, each_hcacc, each_hbacc,
            each_l1val, each_l2val, each_lfval])

    # end for each_ratio...

    # --------------------------------------------------------------------------
    #   Store to the file
    # --------------------------------------------------------------------------
    save_csvfile = os.path.join(save_adir, 'perturbation_summary.{}.csv'.format(_bdr_shape))
    write_to_csv(save_csvfile, tot_csvresults, mode='w')
    print(' : [store] results to [{}]'.format(save_csvfile))


    """
        Draw the plots
    """
    # preset!
    plt.figure(figsize=(9,5))
    sns.set_theme(rc=_sns_configs)

    # -------- draw the plots with l1-norm --------
    _store_norm_plot(tot_l1_results, 'l1', r"Perturbations ($\ell_{1}$)", 1000., save_adir)

    # -------- draw the plots with l2-norm --------
    _store_norm_plot(tot_l2_results, 'l2', r"Perturbations ($\ell_{2}$)", 50., save_adir)

    # -------- draw the plots with linf-norm --------
    _store_norm_plot(tot_lf_results, 'linf', r"Perturbations ($\ell_{inf}$)", 5., save_adir)

    print(' : [defense][perturb] Done!')
    # done.
