# Copyright 2021 The Handcrafted Backdoors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
""" Compare the resilience against fine-tuning """
# basics
import os

# numpy / tensorflow
import numpy as np
np.set_printoptions(suppress=True)

# objax
import objax

# custom
from utils.datasets import load_dataset, blend_backdoor
from utils.learner import train, valid
from utils.models import load_network, load_network_parameters, save_network_parameters


# ------------------------------------------------------------------------------
#   General attack configurations
# ------------------------------------------------------------------------------
_seed    = 215        # seed passed to np.random.seed() by the main script
_dataset = 'pubfig'   # selects which of the dataset branches below is used


# Each branch below defines, for one dataset:
#   _network      : name of the network architecture to load
#   _bdr_*        : backdoor trigger settings (_bdr_label is the label assigned
#                   to every backdoored sample)
#   _net_sbdoor   : checkpoint of the standard backdoored model
#   _net_mbdoor   : checkpoint of the handcrafted backdoored model; it is only
#                   defined for the (network, trigger shape) pairs listed in
#                   the branch
#   _num_sample   : number of training samples drawn for fine-tuning
#   _num_batchs / _num_tunes / _learn_rate : batch size, number of fine-tuning
#                   epochs, and learning rate handed to train() / valid()

## MNIST
if 'mnist' == _dataset:
    # ---- (Feedforward) ----
    _network     = 'FFNet'

    # : backdoor trigger
    _bdr_label   = 0
    _bdr_intense = 1.0
    _bdr_shape   = 'square'   # 'square' / 'checkerboard'
    _bdr_size    = 4

    # : networks
    _net_sbdoor  = 'models/{}/{}/best_model_backdoor_{}_{}_{}_5.npz'.format(
                        _dataset, _network, _bdr_shape, _bdr_size, _bdr_intense)

    if 'square' == _bdr_shape:
        _net_types   = 'handcraft.bdoor'
        _net_mbdoor  = 'models/{}/{}/{}/best_model_handcraft_{}_{}_{}_4.npz'.format(
                            _dataset, _network, _net_types, _bdr_shape, _bdr_size, _bdr_intense)
    elif 'checkerboard' == _bdr_shape:
        _net_types   = 'handcraft.bdoor'
        _net_mbdoor  = 'models/{}/{}/{}/best_model_handcraft_{}_{}_{}_6.npz'.format(
                            _dataset, _network, _net_types, _bdr_shape, _bdr_size, _bdr_intense)

    # : number of training samples used for fine-tuning
    _num_sample  = 10000

    # : defense related configurations
    _num_batchs  = 50
    _num_tunes   = 5
    _learn_rate  = 0.05

## SVHN
elif 'svhn' == _dataset:
    _network     = 'FFNet'

    # : backdoor trigger
    _bdr_label   = 0
    _bdr_intense = 0.0
    _bdr_shape   = 'random'
    _bdr_size    = 4

    # : networks
    _net_sbdoor  = 'models/{}/{}/best_model_backdoor_{}_{}_{}_5.npz'.format(
                        _dataset, _network, _bdr_shape, _bdr_size, _bdr_intense)
    if 'FFNet' == _network:
        if 'square' == _bdr_shape:
            _net_types   = 'handcraft.bdoor'
            _net_mbdoor  = 'models/{}/{}/{}/best_model_handcraft_{}_{}_{}_38.npz'.format(
                                _dataset, _network, _net_types, _bdr_shape, _bdr_size, _bdr_intense)
        elif 'checkerboard' == _bdr_shape:
            _net_types   = 'handcraft.bdoor'
            _net_mbdoor  = 'models/{}/{}/{}/best_model_handcraft_{}_{}_{}_14.npz'.format(
                                _dataset, _network, _net_types, _bdr_shape, _bdr_size, _bdr_intense)
        elif 'random' == _bdr_shape:
            _net_types   = 'handcraft.bdoor'
            _net_mbdoor  = 'models/{}/{}/{}/best_model_handcraft_{}_{}_{}_30.npz'.format(
                                _dataset, _network, _net_types, _bdr_shape, _bdr_size, _bdr_intense)

    elif 'ConvNet' == _network:
        if 'checkerboard' == _bdr_shape:
            _net_types   = 'handcraft.bdoor'
            _net_mbdoor  = 'models/{}/{}/{}/best_model_handcraft_{}_{}_{}_30.npz'.format(
                                _dataset, _network, _net_types, _bdr_shape, _bdr_size, _bdr_intense)
        elif 'random' == _bdr_shape:
            _net_types   = 'handcraft.bdoor'
            _net_mbdoor  = 'models/{}/{}/{}/best_model_handcraft_{}_{}_{}_36.npz'.format(
                                _dataset, _network, _net_types, _bdr_shape, _bdr_size, _bdr_intense)

    # : in SVHN, use the entire validation set
    _num_sample  = 25000

    # : defense related configurations
    _num_batchs  = 50
    _num_tunes   = 5
    _learn_rate  = 0.01

## CIFAR-10
elif 'cifar10' == _dataset:
    _network     = 'ConvNet'

    # : backdoor trigger
    _bdr_label   = 0
    _bdr_intense = 1.0
    _bdr_shape   = 'checkerboard'
    _bdr_size    = 4

    # : networks
    _net_sbdoor  = 'models/{}/{}/best_model_backdoor_{}_{}_{}_10.npz'.format(
                        _dataset, _network, _bdr_shape, _bdr_size, _bdr_intense)
    if 'FFNet' == _network:
        if 'square' == _bdr_shape:
            _net_types   = 'handcraft.bdoor'
            _net_mbdoor  = 'models/{}/{}/{}/best_model_handcraft_{}_{}_{}_12.npz'.format(
                                _dataset, _network, _net_types, _bdr_shape, _bdr_size, _bdr_intense)
        elif 'checkerboard' == _bdr_shape:
            _net_types   = 'handcraft.bdoor'
            _net_mbdoor  = 'models/{}/{}/{}/best_model_handcraft_{}_{}_{}_12.npz'.format(
                                _dataset, _network, _net_types, _bdr_shape, _bdr_size, _bdr_intense)
        elif 'random' == _bdr_shape:
            _net_types   = 'handcraft.bdoor'
            _net_mbdoor  = 'models/{}/{}/{}/best_model_handcraft_{}_{}_{}_24.npz'.format(
                                _dataset, _network, _net_types, _bdr_shape, _bdr_size, _bdr_intense)
    elif 'ConvNet' == _network:
        # NOTE: this branch used to be indented inconsistently, which made the
        # whole file fail to parse (IndentationError).
        if 'checkerboard' == _bdr_shape:
            _net_types   = 'handcraft.bdoor'
            _net_mbdoor  = 'models/{}/{}/{}/best_model_handcraft_{}_{}_{}_12.npz'.format(
                                _dataset, _network, _net_types, _bdr_shape, _bdr_size, _bdr_intense)
        elif 'random' == _bdr_shape:
            _net_types   = 'handcraft.bdoor'
            _net_mbdoor  = 'models/{}/{}/{}/best_model_handcraft_{}_{}_{}_12.npz'.format(
                                _dataset, _network, _net_types, _bdr_shape, _bdr_size, _bdr_intense)

    # : number of training samples used for fine-tuning
    _num_sample  = 10000

    # : defense related configurations
    _num_batchs  = 50
    _num_tunes   = 5
    _learn_rate  = 0.004 if _network == 'FFNet' else 0.025

## PubFig
elif 'pubfig' == _dataset:
    _network     = 'VGGFace'

    # : backdoor trigger
    _bdr_label   = 0
    _bdr_intense = 1.0
    _bdr_shape   = 'trojan'
    _bdr_size    = 24

    # : networks
    _net_sbdoor  = 'models/{}/{}/best_model_backdoor_{}_{}_{}_20.npz'.format(
                        _dataset, _network, _bdr_shape, _bdr_size, _bdr_intense)
    if 'VGGFace' == _network:
        if 'trojan' == _bdr_shape:
            _net_types   = 'handcraft.bdoor'
            _net_mbdoor  = 'models/{}/{}/{}/best_model_handcraft_{}_{}_{}_410.npz'.format(
                                _dataset, _network, _net_types, _bdr_shape, _bdr_size, _bdr_intense)

    # : number of training samples used for fine-tuning
    _num_sample  = 650

    # : defense related configurations
    _num_batchs  = 16
    _num_tunes   = 5
    _learn_rate  = 0.004


# ------------------------------------------------------------------------------
#   Misc. functions
# ------------------------------------------------------------------------------
def _extract_store_location(filepath):
    """Return the path under which the fine-tuned copy of a checkpoint is stored.

    The target directory is ``<up to the first three directories of
    filepath>/finetune``; it is created when it does not exist yet. The
    target filename is the original filename with a trailing ``.npz`` turned
    into ``.finetune.npz`` (a filename without that extension is kept as-is).

    Args:
        filepath: '/'-separated path of the original checkpoint, e.g.
            ``models/<dataset>/<network>/[<type>/]<name>.npz``.

    Returns:
        The path (directory + filename) for the fine-tuned checkpoint.
    """
    ftokens  = filepath.split('/')

    # create the dir. to store
    # (only directory components are considered, so that a shallow path never
    #  turns the checkpoint filename itself into a directory)
    dtokens  = ftokens[:-1][:3]
    filedir  = os.path.join('/'.join(dtokens), 'finetune')
    # exist_ok avoids the race between checking for and creating the directory
    os.makedirs(filedir, exist_ok=True)

    # create the filename to store (rewrite the trailing extension only)
    filename = ftokens[-1]
    if filename.endswith('.npz'):
        filename = filename[:-len('.npz')] + '.finetune.npz'
    return os.path.join(filedir, filename)



"""
    Main (handcraft backdoor attacks)
"""
if __name__ == '__main__':
    # Fine-tune the standard backdoored model (_net_sbdoor) and the handcrafted
    # backdoored model (_net_mbdoor) on clean training samples and report, after
    # every epoch, the accuracy on the clean test data and on the backdoored
    # test data for both models. The fine-tuned parameters are stored at the end.

    # set the random seed (for the reproducible experiments)
    np.random.seed(_seed)

    # data
    (X_train, Y_train), (X_test, Y_test) = load_dataset(_dataset)
    print (' : [load] load the dataset [{}]'.format(_dataset))

    # craft the backdoor dataset from a copy of the test data (the training
    # set is too large for this); every sample is labeled with _bdr_label
    X_bdoor = blend_backdoor(
        np.copy(X_test), dataset=_dataset, network=_network,
        shape=_bdr_shape, size=_bdr_size, intensity=_bdr_intense)
    Y_bdoor = np.full(Y_test.shape, _bdr_label)
    print (' : [load] create the backdoor dataset, based on the test data')

    # choose the samples used for fine-tuning
    # -- it is valid to assume that the defender does not have the capacity
    #    to run fine-tuning with the full training data
    num_train = X_train.shape[0]
    if _num_sample != num_train:
        # an int population is sampled as if it were np.arange(num_train);
        # raises ValueError when _num_sample exceeds num_train (replace=False)
        num_indexes = np.random.choice(num_train, size=_num_sample, replace=False)
        print ('   [load] sample the train dataset [{} -> {}]'.format(num_train, _num_sample))
        X_train = X_train[num_indexes]
        Y_train = Y_train[num_indexes]

    # load two models
    model_sbdoor = load_network(_dataset, _network)
    model_mbdoor = load_network(_dataset, _network)
    print (' : [load] use the network [{}]'.format(_network))

    # load the model parameters
    load_network_parameters(model_sbdoor, _net_sbdoor)
    load_network_parameters(model_mbdoor, _net_mbdoor)
    print (' : [load] load the models')
    print ('   - from: {}'.format(_net_sbdoor))
    print ('   - from: {}'.format(_net_mbdoor))

    # prediction functions (inference mode)
    spredictor = objax.Jit(lambda x: model_sbdoor(x, training=False), model_sbdoor.vars())
    mpredictor = objax.Jit(lambda x: model_mbdoor(x, training=False), model_mbdoor.vars())

    # set the store locations
    # NOTE(review): these directories are created and printed, but nothing in
    # this script writes into them.
    print (' : [load] set the store locations')
    save_adir = os.path.join('analysis', 'finetune', _dataset, _network)
    os.makedirs(save_adir, exist_ok=True)
    print ('   [neurons] - {}'.format(save_adir))
    save_rdir = os.path.join('results', 'finetune', _dataset, _network)
    os.makedirs(save_rdir, exist_ok=True)
    print ('   [results] - {}'.format(save_rdir))


    # --------------------------------------------------------------------------
    #   Check if the re-training with clean data removes the hand-crafted
    #   backdoors
    # --------------------------------------------------------------------------
    def _make_loss(model, **call_kwargs):
        """ Build the mean cross-entropy loss of `model` (run in training mode);
            `call_kwargs` are forwarded to the model call. """
        def _loss(x, y):
            logits = model(x, training=True, **call_kwargs)
            return objax.functional.loss.cross_entropy_logits_sparse(logits, y).mean()
        return _loss

    def _trainable_vars(model, num_skip):
        """ Collect the variables of `model`, leaving out the first `num_skip`
            ones (in the order of model.vars()). """
        return objax.VarCollection(
            (k, v) for i, (k, v) in enumerate(model.vars().items()) if i >= num_skip)

    def _make_finetune_op(gradvals, optimizer):
        """ Build the jitted update step: compute the gradients, apply one SGD
            update with the given learning rate, and return the loss value(s). """
        def _step(x, y, lr):
            g, v = gradvals(x, y)
            optimizer(lr=lr, grads=g)
            return v
        return objax.Jit(_step, gradvals.vars() + optimizer.vars())

    def _evaluate(predictor):
        """ Return (accuracy on the clean test data, accuracy on the backdoored
            test data) of the given prediction function. """
        clean = valid('N/A', X_test, Y_test, _num_batchs, predictor, silient=True)
        bdoor = valid('N/A', X_bdoor, Y_bdoor, _num_batchs, predictor, silient=True)
        return clean, bdoor


    # define losses (these networks take `wodout=True`, i.e., run w/o Dropout)
    if _network in ['ConvNetDeep', 'VGGFace']:
        _loss_sbdoor = _make_loss(model_sbdoor, wodout=True)
        _loss_mbdoor = _make_loss(model_mbdoor, wodout=True)
        print (' : [load] loss functions, do not use Dropout for [{}]'.format(_network))
    else:
        _loss_sbdoor = _make_loss(model_sbdoor)
        _loss_mbdoor = _make_loss(model_mbdoor)
        print (' : [load] loss functions for [{}]'.format(_network))


    # train vars
    # : for pubfig, only the variables at index 24 and later are fine-tuned
    #   (presumably the earlier ones belong to the pre-trained feature
    #   extractor -- TODO confirm); otherwise all the variables are tuned
    num_frozen   = 24 if _dataset == 'pubfig' else 0
    train_sbvars = _trainable_vars(model_sbdoor, num_frozen)
    train_mbvars = _trainable_vars(model_mbdoor, num_frozen)
    print (train_sbvars)

    # compose the optimizers
    sgv  = objax.GradValues(_loss_sbdoor, train_sbvars)
    sopt = objax.optimizer.SGD(train_sbvars)
    mgv  = objax.GradValues(_loss_mbdoor, train_mbvars)
    mopt = objax.optimizer.SGD(train_mbvars)

    # fine-tune functions - trainers
    finetune_sop = _make_finetune_op(sgv, sopt)
    finetune_mop = _make_finetune_op(mgv, mopt)


    # check the accuracy before fine-tuning
    sclean, sbdoor = _evaluate(spredictor)
    mclean, mbdoor = _evaluate(mpredictor)
    print (' : [defense] before finetune')
    print ('   - Standard: clean acc. [{:.3f}] / bdoor acc. [{:.3f}]'.format(sclean, sbdoor))
    print ('   - Handmade: clean acc. [{:.3f}] / bdoor acc. [{:.3f}]'.format(mclean, mbdoor))


    # run the finetunes....
    # NOTE(review): tot_results is collected but never written to disk here.
    tot_results = []

    for epoch in range(_num_tunes):

        # : run one epoch of fine-tuning on each model (returns the loss)
        sb_loss = train(epoch, X_train, Y_train, _num_batchs, finetune_sop, _learn_rate)
        mb_loss = train(epoch, X_train, Y_train, _num_batchs, finetune_mop, _learn_rate)

        # : check the accuracy of a model on the clean/bdoor data (standard)
        sb_clean, sb_bdoor = _evaluate(spredictor)

        # : check the accuracy of a model on the clean/bdoor data (handcraft)
        mb_clean, mb_bdoor = _evaluate(mpredictor)

        # : store the data, as [epoch, value, label] rows
        tot_results.append([epoch+1, sb_clean, 'Acc. (standard)'])
        tot_results.append([epoch+1, sb_bdoor, 'Success (standard)'])
        tot_results.append([epoch+1, mb_clean, 'Acc. (ours)'])
        tot_results.append([epoch+1, mb_bdoor, 'Success (ours)'])

        # : print out
        print (' : [defense][finetune]')
        print ('   - Standard: clean acc. [{:.3f}] / bdoor acc. [{:.3f}] (loss: {:.4f})'.format(sb_clean, sb_bdoor, sb_loss))
        print ('   - Handmade: clean acc. [{:.3f}] / bdoor acc. [{:.3f}] (loss: {:.4f})'.format(mb_clean, mb_bdoor, mb_loss))
    # end for ...


    # location to store the fine-tuned models
    _net_sbdoor_finetune = _extract_store_location(_net_sbdoor)
    _net_mbdoor_finetune = _extract_store_location(_net_mbdoor)
    save_network_parameters(model_sbdoor, _net_sbdoor_finetune)
    save_network_parameters(model_mbdoor, _net_mbdoor_finetune)
    print (' : [defense][finetune] store the models to')
    print ('   - Standard: {}'.format(_net_sbdoor_finetune))
    print ('   - Handmade: {}'.format(_net_mbdoor_finetune))
    # done.
