# Copyright 2021 The Handcrafted Backdoors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
""" Compare the resilience against fine-tuning """
# basics
import os, gc
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
from PIL import Image

# numpy / tensorflow
import numpy as np
np.set_printoptions(suppress=True)

# objax
import objax

# matplotlib
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
sns.set(color_codes=True)

# custom
from utils.io import write_to_csv
from utils.datasets import load_dataset, blend_backdoor
from utils.learner import train, valid
from utils.models import load_network, load_network_parameters
from utils.optimizers import make_optimizer


# ------------------------------------------------------------------------------
#   Plot configurations
# ------------------------------------------------------------------------------
# rc overrides handed to seaborn (sns.set_theme) right before the plot is drawn;
# the groups below are added in the same order as they are meant to be read
_sns_configs = {}

# : font sizes of the base text and the tick labels
_sns_configs.update({
    'font.size': 16,
    'xtick.labelsize': 16,
    'ytick.labelsize': 16,
})

# : axes (background, frame, and label size)
_sns_configs.update({
    'axes.facecolor': 'white',
    'axes.edgecolor': 'black',
    'axes.linewidth': 1.0,
    'axes.labelsize': 16,
})

# : legend box
_sns_configs.update({
    'legend.facecolor': 'white',
    'legend.edgecolor': 'black',
    'legend.fontsize': 16,
})

# : grid lines
_sns_configs.update({
    'grid.color': '#c0c0c0',
    'grid.linestyle': ':',
    'grid.linewidth': 0.8,
})


# ------------------------------------------------------------------------------
#   General attack configurations
# ------------------------------------------------------------------------------
# seed for numpy, the dataset to run on, and whether to run the (verbose)
# parameter comparison after fine-tuning
_seed    = 215
_dataset = 'pubfig'
_compare = False

## CIFAR-10
if _dataset == 'cifar10':
    _network = 'ConvNet'

    # : backdoor trigger (target label, intensity, pattern, and size)
    _bdr_label, _bdr_intense = 0, 1.0
    _bdr_shape, _bdr_size    = 'trojan', 4

    # : optimized backdoor (for the mitm models)
    _bdr_fstore = 'datasets/mitm/{}/{}/best_model_base.npz'.format(_dataset, _network)
    _bdr_fpatch = os.path.join(_bdr_fstore, 'x_patch.{}.png'.format(_bdr_shape))
    _bdr_fmasks = os.path.join(_bdr_fstore, 'x_masks.{}.png'.format(_bdr_shape))

    # : networks - the standard backdoored model
    _net_sbdoor = 'models/{}/{}/best_model_backdoor_{}_{}_{}_10.npz'.format(
        _dataset, _network, _bdr_shape, _bdr_size, _bdr_intense)

    # : networks - the handcrafted model; each supported trigger pattern has
    #   its own checkpoint name (other patterns leave these names unset)
    _net_mbdoor_templates = {
        'square'      : 'models/{}/{}/{}/best_model_handcraft_{}_0.95.mitm.npz',
        'checkerboard': 'models/{}/{}/{}/best_model_handcraft_{}_0.9.mitm.npz',
        'random'      : 'models/{}/{}/{}/best_model_handcraft_{}_0.9.mitm.npz',
        'trojan'      : 'models/{}/{}/{}/best_model_handcraft_{}_0.99.mitm.npz',
    }
    if _bdr_shape in _net_mbdoor_templates:
        _net_types  = 'handcraft.bdoor'
        _net_mbdoor = _net_mbdoor_templates[_bdr_shape].format(
            _dataset, _network, _net_types, _bdr_shape)

    # : no-optimization (number of test samples to use)
    _num_valids = 10000

    # : defense related configurations
    _num_batchs, _num_tunes = 50, 5
    _optimizer,  _learn_rate = 'SGD', 0.004

## PubFig
elif _dataset == 'pubfig':
    _network = 'InceptionResNetV1'

    # : backdoor trigger (target label, intensity, pattern, and size)
    _bdr_label, _bdr_intense = 0, 1.0
    _bdr_shape, _bdr_size    = 'trojan', 24

    # : optimized backdoor (for the mitm models)
    _bdr_fstore = 'datasets/mitm/{}/{}/best_model_base.npz'.format(_dataset, _network)
    _bdr_fpatch = os.path.join(_bdr_fstore, 'x_patch.{}.png'.format(_bdr_shape))
    _bdr_fmasks = os.path.join(_bdr_fstore, 'x_masks.{}.png'.format(_bdr_shape))

    # : networks - the standard backdoored model
    _net_sbdoor = 'models/{}/{}/best_model_backdoor_{}_{}_{}_10.npz'.format(
        _dataset, _network, _bdr_shape, _bdr_size, _bdr_intense)

    # : networks - the handcrafted model (only the trojan pattern is set up)
    if _bdr_shape == 'trojan':
        _net_types  = 'handcraft.bdoor'
        _net_mbdoor = 'models/{}/{}/{}/best_model_handcraft_{}_0.99.mitm.npz'.format(
            _dataset, _network, _net_types, _bdr_shape)

    # : no-optimization (number of test samples to use)
    _num_valids = 650

    # : defense related configurations
    _num_batchs, _num_tunes = 25, 5
    _optimizer,  _learn_rate = 'SGD', 0.004


# ------------------------------------------------------------------------------
#   Misc. functions
# ------------------------------------------------------------------------------
def _store_csvfile(filename, datalines, mode='w'):
    """Format result rows and hand them to `write_to_csv`.

    The first two fields of every row are kept as they are; the remaining
    fields are rendered as fixed-point strings with six decimals. The row
    width is taken from the first row and must be 4 or 5.

    Args:
        filename: location passed through to `write_to_csv`.
        datalines: sequence of rows (each row is an indexable of 4 or 5 fields).
        mode: file mode passed through to `write_to_csv` (default: 'w').

    Raises:
        AssertionError: when the first row has an unsupported number of fields.
    """
    # nothing to reformat: pass the empty data through rather than failing
    # on the lookup of the first row
    if len(datalines) == 0:
        write_to_csv(filename, datalines, mode=mode)
        return

    # validate with an explicit raise (a bare `assert` is stripped under -O,
    # which would let unformatted rows slip through)
    num_fields = len(datalines[0])
    if num_fields not in (4, 5):
        raise AssertionError(
            'Error: unsupported data format - len: {}'.format(num_fields))

    # reformat
    datalines = [
        [eachdata[0], eachdata[1]] + \
            ['{:.6f}'.format(eachdata[fidx]) for fidx in range(2, num_fields)]
        for eachdata in datalines]

    # store
    write_to_csv(filename, datalines, mode=mode)
    # done.


"""
    Main (handcraft backdoor attacks)
"""
# Script body: load the test data and the two backdoored models (standard and
# handcrafted), fine-tune both on the clean samples for `_num_tunes` epochs,
# report the clean accuracy / backdoor success after every epoch, optionally
# compare the tuned parameters with the originals, and plot the results.
if __name__ == '__main__':

    # set the random seed (for the reproducible experiments)
    np.random.seed(_seed)


    # --------------------------------------------------------------------------
    #   Dataset
    # --------------------------------------------------------------------------
    (x_train, y_train), (x_test, y_test) = load_dataset(_dataset)
    print (' : [load] load the dataset [{}]'.format(_dataset))

    # choose the subset of the entire dataset
    if _num_valids != x_test.shape[0]:
        num_indexes = np.random.choice(range(x_test.shape[0]), size=_num_valids, replace=False)
        print ('   [load] subset of the train dataset [{} -> {}]'.format(x_test.shape[0], _num_valids))
        x_test = x_test[num_indexes]
        y_test = y_test[num_indexes]

    # craft the backdoor datasets (standard); every sample gets the target label
    x_sbdoor = blend_backdoor( \
        np.copy(x_test), dataset=_dataset, network=_network, \
        shape=_bdr_shape, size=_bdr_size, intensity=_bdr_intense)
    y_sbdoor = np.full(y_test.shape, _bdr_label)
    print (' : [load] create the backdoor dataset (standard)')

    # craft the backdoor dataset (mitm-models)
    # : the image files are closed as soon as their pixels are in arrays;
    #   the last axis is moved to the front and the values are divided by 255
    with Image.open(_bdr_fpatch) as patch_image:
        x_patch = np.asarray(patch_image).transpose(2, 0, 1) / 255.
    with Image.open(_bdr_fmasks) as masks_image:
        x_masks = np.asarray(masks_image).transpose(2, 0, 1) / 255.

    # blend the backdoor patch ...
    # : a leading axis of length one is enough; broadcasting applies the same
    #   patch/mask to every sample without materializing per-sample copies
    xp = np.expand_dims(x_patch, axis=0)
    xm = np.expand_dims(x_masks, axis=0)
    x_hbdoor = x_test * (1-xm) + xp * xm
    y_hbdoor = np.full(y_test.shape, _bdr_label)
    print (' : [load] create the backdoor dataset (mitm-models)')

    gc.collect()    # to control the memory space


    # --------------------------------------------------------------------------
    #   Models
    # --------------------------------------------------------------------------
    # load the network
    model_sbdoor = load_network(_dataset, _network)
    model_mbdoor = load_network(_dataset, _network)
    print (' : [load] use the network [{}]'.format(_network))

    # load the model parameters
    load_network_parameters(model_sbdoor, _net_sbdoor)
    load_network_parameters(model_mbdoor, _net_mbdoor)
    print (' : [load] load the models')
    print ('   - from: {}'.format(_net_sbdoor))
    print ('   - from: {}'.format(_net_mbdoor))


    # set the store locations
    print (' : [load] set the store locations')
    save_adir = os.path.join('analysis', 'finetune', _dataset, _network)
    os.makedirs(save_adir, exist_ok=True)
    print ('   [neurons] - {}'.format(save_adir))
    save_rdir = os.path.join('results', 'finetune', _dataset, _network)
    os.makedirs(save_rdir, exist_ok=True)
    print ('   [results] - {}'.format(save_rdir))


    # --------------------------------------------------------------------------
    #   Check if the re-training with clean data removes the hand-crafted backdoors
    # --------------------------------------------------------------------------
    # define losses (mean sparse cross-entropy over the batch, training mode)
    def _loss_sbdoor(x, y):
        logits = model_sbdoor(x, training=True)
        return objax.functional.loss.cross_entropy_logits_sparse(logits, y).mean()
    def _loss_mbdoor(x, y):
        logits = model_mbdoor(x, training=True)
        return objax.functional.loss.cross_entropy_logits_sparse(logits, y).mean()
    print (' : [load] loss functions for [{}]'.format(_network))


    # train vars
    # : for the two face networks only the variables whose position in
    #   `vars()` is above a fixed index are tuned; otherwise all of them
    if _network in ['VGGFace']:
        train_sbvars = objax.VarCollection( (k, v) for i, (k, v) in enumerate(model_sbdoor.vars().items()) if i > 23 )
        train_mbvars = objax.VarCollection( (k, v) for i, (k, v) in enumerate(model_mbdoor.vars().items()) if i > 23 )
    elif _network in ['InceptionResNetV1']:
        train_sbvars = objax.VarCollection( (k, v) for i, (k, v) in enumerate(model_sbdoor.vars().items()) if i > 602 )
        train_mbvars = objax.VarCollection( (k, v) for i, (k, v) in enumerate(model_mbdoor.vars().items()) if i > 602 )
    else:
        train_sbvars = objax.VarCollection( (k, v) for i, (k, v) in enumerate(model_sbdoor.vars().items()) )
        train_mbvars = objax.VarCollection( (k, v) for i, (k, v) in enumerate(model_mbdoor.vars().items()) )
    # print (train_sbvars)

    # compose the optimizers
    sgv  = objax.GradValues(_loss_sbdoor, train_sbvars)
    sopt = make_optimizer(train_sbvars, _optimizer)
    mgv  = objax.GradValues(_loss_mbdoor, train_mbvars)
    mopt = make_optimizer(train_mbvars, _optimizer)

    # prediction function
    # : each predictor is jitted over the variables of the model it evaluates,
    #   so the handcrafted predictor has to be bound to `train_mbvars`
    spredictor = objax.Jit(lambda x: model_sbdoor(x, training=False), train_sbvars)
    mpredictor = objax.Jit(lambda x: model_mbdoor(x, training=False), train_mbvars)

    # fine-tune functions -  trainers (one gradient step, returns the loss values)
    def finetune_sop(x, y, lr):
        g, v = sgv(x, y)
        sopt(lr=lr, grads=g)
        return v

    # : the trainers are left un-jitted for InceptionResNetV1
    if _network not in ['InceptionResNetV1']:
        finetune_sop = objax.Jit(finetune_sop, sgv.vars() + sopt.vars())

    def finetune_mop(x, y, lr):
        g, v = mgv(x, y)
        mopt(lr=lr, grads=g)
        return v

    if _network not in ['InceptionResNetV1']:
        finetune_mop = objax.Jit(finetune_mop, mgv.vars() + mopt.vars())


    # check the accuracy before fine-tuning
    sclean = valid('N/A', x_test, y_test, _num_batchs, spredictor, silient=True)
    sbdoor = valid('N/A', x_sbdoor, y_sbdoor, _num_batchs, spredictor, silient=True)
    mclean = valid('N/A', x_test, y_test, _num_batchs, mpredictor, silient=True)
    mbdoor = valid('N/A', x_hbdoor, y_hbdoor, _num_batchs, mpredictor, silient=True)
    print (' : [defense] before finetune')
    print ('   - Standard: clean acc. [{:.3f}] / bdoor acc. [{:.3f}]'.format(sclean, sbdoor))
    print ('   - Handmade: clean acc. [{:.3f}] / bdoor acc. [{:.3f}]'.format(mclean, mbdoor))


    # run the finetunes....
    # : rows of [epoch, value, metric name], turned into a table for plotting
    tot_results = []

    for epoch in range(_num_tunes):
        # : compute the losses
        sb_loss = train(epoch, x_test, y_test, _num_batchs, finetune_sop, _learn_rate)
        mb_loss = train(epoch, x_test, y_test, _num_batchs, finetune_mop, _learn_rate)

        # : check the accuracy of a model on the clean/bdoor data (standard)
        sb_clean = valid('N/A', x_test, y_test, _num_batchs, spredictor, silient=True)
        sb_bdoor = valid('N/A', x_sbdoor, y_sbdoor, _num_batchs, spredictor, silient=True)

        # : check the accuracy of a model on the clean/bdoor data (handcraft)
        mb_clean = valid('N/A', x_test, y_test, _num_batchs, mpredictor, silient=True)
        mb_bdoor = valid('N/A', x_hbdoor, y_hbdoor, _num_batchs, mpredictor, silient=True)

        # : store the data
        tot_results.append([epoch+1, sb_clean, 'Acc. (standard)'])
        tot_results.append([epoch+1, sb_bdoor, 'Success (standard)'])
        tot_results.append([epoch+1, mb_clean, 'Acc. (ours)'])
        tot_results.append([epoch+1, mb_bdoor, 'Success (ours)'])
        # tot_results.append([epoch+1, sb_clean, sb_bdoor, mb_clean, mb_bdoor])

        # : print out
        print (' : [defense][finetune]')
        print ('   - Standard: clean acc. [{:.3f}] / bdoor acc. [{:.3f}] (loss: {:.4f})'.format(sb_clean, sb_bdoor, sb_loss))
        print ('   - Handmade: clean acc. [{:.3f}] / bdoor acc. [{:.3f}] (loss: {:.4f})'.format(mb_clean, mb_bdoor, mb_loss))
    # end for ...


    # --------------------------------------------------------------------------
    #   To compare the scenarios where we run fine-tuning endlessly...
    # --------------------------------------------------------------------------
    def _report_param_updates(base_model, tuned_model, name_format, num_cparams=100):
        # Print, for every variable, the largest differences (by magnitude)
        # between the tuned model and the baseline it was loaded from.
        for (olnum, olparams), (mlnum, mlparams) \
            in zip(base_model.vars().items(), tuned_model.vars().items()):
            if olnum != mlnum:
                raise AssertionError('Error: comparing different models, abort')

            # : upper-bound (the length of the first axis caps the report)
            # NOTE(review): capped by shape[0], not by the number of elements
            #   - confirm that this is the intended bound
            cur_num_cparams = min(olparams.value.shape[0], num_cparams)

            # : compare parameters
            difference_parameters = mlparams.value - olparams.value

            # : sort the flattened differences by magnitude, largest first
            difference_paramsflat = difference_parameters.flatten()
            difference_plocations = np.argsort(np.absolute(difference_paramsflat))[::-1]
            difference_paramsflat = difference_paramsflat[difference_plocations]

            # : report the largest updates
            print (name_format.format(olnum))
            for nidx in range(cur_num_cparams):
                print ('     Params [update: {:.6f}] @ {}'.format( \
                    difference_paramsflat[nidx], np.unravel_index(difference_plocations[nidx], difference_parameters.shape)))

    if _compare:
        print (' : [defense] Compare the model parameters tuned a lot')
        model_sbline = load_network(_dataset, _network)
        model_mbline = load_network(_dataset, _network)
        load_network_parameters(model_sbline, _net_sbdoor)
        load_network_parameters(model_mbline, _net_mbdoor)
        print ('  - load the baseline (standard/handcrafted) models')

        # : check with the standard case
        print ('  - Compare the standard models  --------')
        _report_param_updates(model_sbline, model_sbdoor, '    @ {}')

        # : check with the handcrafted case
        print ('  - Compare the handcrafted models --------')
        _report_param_updates(model_mbline, model_mbdoor, '    @ [{}]')

    # end if _compare...

    # compose the total data (one row per epoch, one column per metric)
    tot_results = pd.DataFrame(tot_results, columns=['Epoch', 'Numbers', 'Metrics'])
    tot_results = tot_results.pivot(index='Epoch', columns='Metrics', values='Numbers')

    # draw plots for the clean acc and backdoor successes
    plt.figure(figsize=(9,5))
    sns.set_theme(rc=_sns_configs)

    # plot lines
    sns.lineplot(data=tot_results, linewidth=1.4)

    plt.xlim(1, _num_tunes)
    plt.xlabel("Epoch")

    # dataset-specific y-ranges (other datasets keep the automatic range)
    if 'mnist' == _dataset:
        if 'square' == _bdr_shape:
            plt.ylim(93., 100.)
        elif 'checkerboard' == _bdr_shape:
            if 'hc_vanilla' == _net_types:
                plt.ylim(85., 100.)     # vanilla
            else:
                plt.ylim(95., 100.)     # finetune evasion

    elif 'svhn' == _dataset:
        if 'FFNet' == _network:
            plt.ylim(70., 100.)
        elif 'ConvNet' == _network:
            if 'hc_vanilla' == _net_types:
                plt.ylim(20., 100.)
            else:
                plt.ylim(20., 100.)

    plt.ylabel("Accuracy / Backdoor Success (%)")

    plt.subplots_adjust(**{
        'top'   : 0.970,
        'bottom': 0.120,
        'left'  : 0.112,
        'right' : 0.962,
    })
    plt.legend(loc='lower left')

    plot_filename = '{}.{}_{}_{}_finetune.eps'.format(_net_types, _bdr_shape, _bdr_size, _bdr_intense)
    plot_filename = os.path.join(save_adir, plot_filename)
    plt.savefig(plot_filename)
    plt.clf()
    print (' : [defense][finetune] finetune resilience plot stored to [{}]'.format(plot_filename))
    print (' : [defense][finetune] Done!')
    # done.
