# Copyright 2021 The Handcrafted Backdoors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
""" Compute the resilience against perturbations """
# basics
import os, gc
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
from PIL import Image
from tqdm import tqdm

# numpy / tensorflow
import numpy as np
np.set_printoptions(suppress=True)

# objax
import objax

# utils
from utils.io import write_to_csv
from utils.datasets import load_dataset, blend_backdoor
from utils.models import load_network, load_network_parameters
from utils.learner import valid

# matplotlib
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
sns.set(color_codes=True)



# ------------------------------------------------------------------------------
#   Plot configurations
# ------------------------------------------------------------------------------
# matplotlib rc overrides handed to `sns.set_theme(rc=...)` when the clipping
# plot is drawn: 16pt text everywhere, white/black axes and legend frames,
# and a thin dotted light-gray grid.
_sns_configs = dict([
    # : font / tick-label sizes
    ('font.size',        16),
    ('xtick.labelsize',  16),
    ('ytick.labelsize',  16),
    # : axes frame
    ('axes.facecolor',   'white'),
    ('axes.edgecolor',   'black'),
    ('axes.linewidth',   1.0),
    ('axes.labelsize',   16),
    # : legend box
    ('legend.facecolor', 'white'),
    ('legend.edgecolor', 'black'),
    ('legend.fontsize',  16),
    # : background grid
    ('grid.color',       '#c0c0c0'),
    ('grid.linestyle',   ':'),
    ('grid.linewidth',   0.8),
])


# ------------------------------------------------------------------------------
#   General attack configurations
# ------------------------------------------------------------------------------
_seed    = 215
_dataset = 'pubfig'

## CIFAR10
if 'cifar10' == _dataset:
    # : experimentation-related parameters
    _num_batchs  = 50
    # : multiples of the weight std. used as clipping thresholds
    # _clip_ratios = np.arange(12.0, 4.0, -0.2).tolist()    # for FFNets
    _clip_ratios = np.arange(70.0, 10.0, -2.0).tolist()     # for ConvNets

    # : backdoor configs (_bdr_label is the backdoor target label)
    _bdr_label   = 0
    _bdr_shape   = 'trojan'
    _bdr_size    = 4
    _bdr_intense = 1.0

    # : networks to load
    _network     = 'ConvNet'
    _nettype     = 'handcraft.bdoor'

    # : filename suffix of the handcrafted checkpoint, per (network, backdoor shape)
    _nethbdoor_suffixes = {
        ('ConvNet', 'square')      : '0.95',
        ('ConvNet', 'checkerboard'): '0.9',
        ('ConvNet', 'random')      : '0.9',
        ('ConvNet', 'trojan')      : '0.99',
    }

    # : the number of test samples to use
    _num_valids  = 10000


## PubFigs
elif 'pubfig' == _dataset:
    # : experimentation-related parameters
    _num_batchs  = 25
    # : multiples of the weight std. used as clipping thresholds
    _clip_ratios = np.arange(70.0, 10.0, -2.0).tolist()     # for ConvNets

    # : backdoor configs (_bdr_label is the backdoor target label)
    _bdr_label   = 0
    _bdr_shape   = 'trojan'
    _bdr_size    = 24
    _bdr_intense = 1.0

    # : networks to load
    _network     = 'InceptionResNetV1'
    _nettype     = 'handcraft.bdoor'

    # : filename suffix of the handcrafted checkpoint, per (network, backdoor shape)
    _nethbdoor_suffixes = {
        ('InceptionResNetV1', 'trojan'): '0.99',
    }

    # : the number of test samples to use
    _num_valids  = 650

else:
    # fail fast instead of a NameError much later in the main script
    raise ValueError('Unsupported dataset [{}]'.format(_dataset))


# : backdoor (optimized patches, blended in by the handcrafted / MITM attack)
_bdr_fpatch  = 'datasets/mitm/{}/{}/best_model_base.npz/x_patch.{}.png'.format(_dataset, _network, _bdr_shape)
_bdr_fmasks  = 'datasets/mitm/{}/{}/best_model_base.npz/x_masks.{}.png'.format(_dataset, _network, _bdr_shape)

# : (standard network)
_netsbdoor   = 'models/{}/{}/best_model_backdoor_{}_{}_{}_10.npz'.format( \
                    _dataset, _network, _bdr_shape, _bdr_size, _bdr_intense)

# : (handcrafted models) -- only the listed (network, shape) pairs have a checkpoint
if (_network, _bdr_shape) not in _nethbdoor_suffixes:
    raise ValueError('No handcrafted model configured for [{}/{}/{}]'.format( \
                        _dataset, _network, _bdr_shape))
_nethbdoor   = 'models/{}/{}/{}/best_model_handcraft_{}_{}.mitm.npz'.format( \
                    _dataset, _network, _nettype, _bdr_shape, \
                    _nethbdoor_suffixes[(_network, _bdr_shape)])


# ------------------------------------------------------------------------------
#   Support functions
# ------------------------------------------------------------------------------
def _profile_standard_deviations(sbmodel, hbmodel, dataset=None):
    """Compute the std. of the weight ('.w') parameters of both models.

    Variable names are taken from ``sbmodel``. ``hbmodel`` is indexed with the
    same names, so both models must share the same variable naming.

    Args:
        sbmodel: the standard (poisoning) backdoored model; ``vars()`` must
            return a name -> variable mapping whose values expose ``.value``.
        hbmodel: the handcrafted backdoored model (same variable names).
        dataset: dataset name that decides which variables are profiled;
            defaults to the module-level ``_dataset``.

    Returns:
        ``(sb_std, hb_std)``: numpy std (ddof=0) over all selected weights of
        the standard and the handcrafted model, respectively.

    Raises:
        ValueError: when no variable qualifies (the std would be NaN).
    """
    if dataset is None:
        dataset = _dataset

    sbmodel_vars = sbmodel.vars()
    hbmodel_vars = hbmodel.vars()

    # retrieve the weight-variable names; for pubfig only indices > 23 count
    # NOTE(review): _clip_parameters skips only indices < 23, i.e. it still
    #   clips index 23 -- confirm which bound is intended.
    min_index = 23 if 'pubfig' == dataset else -1
    lnames = [each for lidx, each in enumerate(sbmodel_vars.keys()) \
              if '.w' in each and lidx > min_index]
    if not lnames:
        raise ValueError('No weight (\'.w\') variables found to profile')

    def _flat_std(model_vars):
        # one float64 vector, same values and order as the former
        # list-of-Python-floats approach, without the per-element overhead
        flat = np.concatenate([ \
            np.asarray(model_vars[name].value, dtype=np.float64).ravel() \
            for name in lnames])
        return flat.std()

    # the handcrafted std. must come from hbmodel (it used to read sbmodel)
    return _flat_std(sbmodel_vars), _flat_std(hbmodel_vars)


def _clip_parameters(model, clipval, dataset=None):
    """Clip the non-bias parameters of ``model`` into ``[-clipval, clipval]``.

    The variables are updated in place through their ``assign`` method.

    Args:
        model: model whose ``vars()`` returns a name -> variable mapping; each
            variable exposes ``.value`` and ``.assign(...)``.
        clipval: absolute clipping threshold.
        dataset: dataset name that decides which variables are skipped;
            defaults to the module-level ``_dataset``.

    Returns:
        The same ``model`` object, for call chaining.
    """
    if dataset is None:
        dataset = _dataset

    # loop over the entire parameters
    for lidx, (lname, lparams) in enumerate(model.vars().items()):
        # : (pubfig) leave the first 23 variables untouched
        # NOTE(review): _profile_standard_deviations only profiles indices
        #   > 23, so index 23 is clipped but not profiled -- confirm intended.
        if 'pubfig' == dataset and lidx < 23: continue

        # : skip biases (any name containing '.b')
        # NOTE(review): variables whose names contain neither '.w' nor '.b'
        #   are clipped here, although they are not part of the std profile.
        if '.b' in lname: continue

        # : clip and write back through the variable itself (no exec needed)
        lparams.assign(np.clip(lparams.value, -clipval, clipval))

    # end for lname...
    return model



# ------------------------------------------------------------------------------
#   Main: clip the weights of the standard (poisoning) and the handcrafted
#   backdoored models at several multiples of their weight std., and record
#   clean accuracy / backdoor success at each ratio.
# ------------------------------------------------------------------------------
if __name__ == '__main__':

    # set the random seed (for the reproducible experiments)
    np.random.seed(_seed)

    # data
    (x_train, y_train), (x_valid, y_valid) = load_dataset(_dataset)
    print (' : [load] load the dataset [{}]'.format(_dataset))

    # remove the training data (only the test-time data is used below)
    del x_train, y_train; gc.collect()
    print ('   [load] delete unused training data')

    # choose samples (when using entire test-time data is too large);
    # sampled without replacement, reproducible through the seed above
    if _num_valids != x_valid.shape[0]:
        num_indexes = np.random.choice(x_valid.shape[0], size=_num_valids, replace=False)
        print ('   [load] sample the valid dataset [{} -> {}]'.format(x_valid.shape[0], _num_valids))
        x_valid = x_valid[num_indexes]
        y_valid = y_valid[num_indexes]

    # craft the backdoor for the standard dataset (all samples -> target label)
    x_sbdoor = blend_backdoor( \
        np.copy(x_valid), dataset=_dataset, network=_network, \
        shape=_bdr_shape, size=_bdr_size, intensity=_bdr_intense)
    y_sbdoor = np.full(y_valid.shape, _bdr_label)
    print (' : [load] create the backdoor dataset, on the test-data, standard')

    # craft the backdoor datasets (use only the test-time data): load the
    # optimized patch / mask images, move the channel axis first and rescale
    # [0, 255] -> [0, 1]; `with` makes sure the image files get closed
    with Image.open(_bdr_fpatch) as each_image:
        x_patch = np.asarray(each_image).transpose(2, 0, 1) / 255.
    with Image.open(_bdr_fmasks) as each_image:
        x_masks = np.asarray(each_image).transpose(2, 0, 1) / 255.

    # blend the backdoor patch: x' = x * (1 - m) + p * m; the leading size-1
    # axis broadcasts over the batch, so no per-sample copies are materialized
    xp = np.expand_dims(x_patch, axis=0)
    xm = np.expand_dims(x_masks, axis=0)
    x_hbdoor = x_valid * (1-xm) + xp * xm
    y_hbdoor = np.full(y_valid.shape, _bdr_label)
    print (' : [load] create the backdoor dataset, on the test-data, handtune (MITM)')


    # --------------------------------------------------------------------------
    #   Load the standard backdoor (poisoning) and the handcrafted version (ours)
    # --------------------------------------------------------------------------
    # load models
    sbmodel = load_network(_dataset, _network)
    hbmodel = load_network(_dataset, _network)
    print (' : [load] the standard/handcrafted b-door models')

    # load parameters
    load_network_parameters(sbmodel, _netsbdoor)
    load_network_parameters(hbmodel, _nethbdoor)
    print (' : [load] load the model parameters')
    print ('   - Standard  b-door [{}]'.format(_netsbdoor))
    print ('   - Handcraft b-door [{}]'.format(_nethbdoor))

    # prediction function
    spredictor = objax.Jit(lambda x: sbmodel(x, training=False), sbmodel.vars())
    hpredictor = objax.Jit(lambda x: hbmodel(x, training=False), hbmodel.vars())


    # set the store locations
    # NOTE(review): save_rdir is created but nothing below writes to it.
    print (' : [load] set the store locations')
    save_adir = os.path.join('analysis', 'clipping', _dataset, _network)
    os.makedirs(save_adir, exist_ok=True)
    print ('   [neurons] - {}'.format(save_adir))
    save_rdir = os.path.join('results', 'clipping', _dataset, _network)
    os.makedirs(save_rdir, exist_ok=True)
    print ('   [results] - {}'.format(save_rdir))


    # --------------------------------------------------------------------------
    #   (Profile) the std. of the weight values; clipping thresholds are
    #   expressed as multiples of these
    # --------------------------------------------------------------------------
    sb_stdval, hb_stdval = _profile_standard_deviations(sbmodel, hbmodel)
    print (' : [profile] std values in [{:.3f} (standard) / {:.3f} (handtune)]'.format(sb_stdval, hb_stdval))


    # --------------------------------------------------------------------------
    #   (Defense) clipping the weight values and examine the acc. and success
    # --------------------------------------------------------------------------
    # data-holders: long-form rows for plotting, wide rows for the csv
    tot_results = []
    tot_csvdata = []

    # compute the clean and backdoor accuracies (initial.)
    sclean_acc = valid('N/A', x_valid,  y_valid,  _num_batchs, spredictor, silient=True)
    sbdoor_acc = valid('N/A', x_sbdoor, y_sbdoor, _num_batchs, spredictor, silient=True)
    hclean_acc = valid('N/A', x_valid,  y_valid,  _num_batchs, hpredictor, silient=True)
    hbdoor_acc = valid('N/A', x_hbdoor, y_hbdoor, _num_batchs, hpredictor, silient=True)
    print (' : [defense][clipping] metrics')
    print ('  - Standard: accuracy [{:.3f}] / bdoor success [{:.3f}]'.format(sclean_acc, sbdoor_acc))
    print ('  - Handmade: accuracy [{:.3f}] / bdoor success [{:.3f}]'.format(hclean_acc, hbdoor_acc))

    # store the initial results (no clipping == ratio of inf)
    tot_results.append([np.inf, sclean_acc, 'Acc. (standard)'])
    tot_results.append([np.inf, sbdoor_acc, 'Success (standard)'])
    tot_results.append([np.inf, hclean_acc, 'Acc. (ours)'])
    tot_results.append([np.inf, hbdoor_acc, 'Success (ours)'])

    tot_csvdata.append([ \
        'Acc. (standard)', 'Success (standard)', \
        'Acc. (ours)', 'Success (ours)', 'Clip ratio'])
    tot_csvdata.append([sclean_acc, sbdoor_acc, hclean_acc, hbdoor_acc, np.inf])


    # loop over the ratios
    for each_ratio in tqdm(_clip_ratios, desc=' : [defense][clipping]'):

        # : load models and parameters each time (clipping is done in place)
        each_sbmodel = load_network(_dataset, _network)
        each_hbmodel = load_network(_dataset, _network)

        load_network_parameters(each_sbmodel, _netsbdoor)
        load_network_parameters(each_hbmodel, _nethbdoor)

        # : clip the weights with the ratio
        each_sbmodel = _clip_parameters(each_sbmodel, each_ratio * sb_stdval)
        each_hbmodel = _clip_parameters(each_hbmodel, each_ratio * hb_stdval)

        # : compose the predictors
        each_spredictor = objax.Jit(lambda x: each_sbmodel(x, training=False), each_sbmodel.vars())
        each_hpredictor = objax.Jit(lambda x: each_hbmodel(x, training=False), each_hbmodel.vars())

        # : measure the accuracy and backdoor success rates
        sclean_acc = valid('N/A', x_valid,  y_valid,  _num_batchs, each_spredictor, silient=True)
        sbdoor_acc = valid('N/A', x_sbdoor, y_sbdoor, _num_batchs, each_spredictor, silient=True)
        hclean_acc = valid('N/A', x_valid,  y_valid,  _num_batchs, each_hpredictor, silient=True)
        hbdoor_acc = valid('N/A', x_hbdoor, y_hbdoor, _num_batchs, each_hpredictor, silient=True)
        # : store...
        tot_results.append([each_ratio, sclean_acc, 'Acc. (standard)'])
        tot_results.append([each_ratio, sbdoor_acc, 'Success (standard)'])
        tot_results.append([each_ratio, hclean_acc, 'Acc. (ours)'])
        tot_results.append([each_ratio, hbdoor_acc, 'Success (ours)'])

        # : for store to csvfile
        tot_csvdata.append([sclean_acc, sbdoor_acc, hclean_acc, hbdoor_acc, each_ratio])

    # end for each_ratio...

    # --------------------------------------------------------------------------
    #   Store to the file
    # --------------------------------------------------------------------------
    save_csvfile = os.path.join(save_adir, 'clipping_summary.{}.mitm.csv'.format(_bdr_shape))
    write_to_csv(save_csvfile, tot_csvdata, mode='w')
    print (' : [store] results to [{}]'.format(save_csvfile))


    # --------------------------------------------------------------------------
    #   Draw the plots
    # --------------------------------------------------------------------------
    # compose the total data: one column per metric, indexed by the clip ratio
    # (keyword arguments -- pandas >= 2.0 made DataFrame.pivot keyword-only)
    tot_results = pd.DataFrame(tot_results, columns=['Clip ratio', 'Numbers', 'Metrics'])
    tot_results = tot_results.pivot(index='Clip ratio', columns='Metrics', values='Numbers')

    # preset!
    plt.figure(figsize=(9,5))
    sns.set_theme(rc=_sns_configs)

    sns.lineplot(data=tot_results, linewidth=1.4)

    # the x-limits leave out the un-clipped (inf) reference point
    plt.xlim(min(_clip_ratios)+0.1, max(_clip_ratios))
    plt.ylim(0., 100.)
    # raw string: '\s' is an invalid escape in a plain literal (SyntaxWarning on 3.12+)
    plt.xlabel(r"Clipping ratio (x $\sigma$)")
    plt.ylabel("Accuracy / Backdoor Success (%)")

    plt.subplots_adjust(**{
        'top'   : 0.970,
        'bottom': 0.120,
        'left'  : 0.112,
        'right' : 0.962,
    })
    plt.legend()

    plot_filename = '{}.l1.{}_{}_{}_clip.mitm.eps'.format(_nettype, _bdr_shape, _bdr_size, _bdr_intense)
    plot_filename = os.path.join(save_adir, plot_filename)
    plt.savefig(plot_filename)
    plt.clf()
    print (' : [defense][perturb] clipping resilience plot is stored to [{}], done!'.format(plot_filename))
    # done.
