# Copyright 2021 The Handcrafted Backdoors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
""" Run the fine-pruning defenses """
# basics
import os, gc
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
from PIL import Image
from ast import literal_eval

# to disable future warnings
import warnings
warnings.simplefilter(action='ignore', category=UserWarning)
warnings.simplefilter(action='ignore', category=FutureWarning)

# numpy
import numpy as np
np.set_printoptions(suppress=True)

# jax/objax
import objax

# matplotlib
import seaborn as sns
sns.set(color_codes=True)

# utils
from utils.io import write_to_csv, load_from_csv
from utils.datasets import load_dataset, blend_backdoor
from utils.models import load_network, load_network_parameters
from utils.learner import train, valid
from utils.optimizers import make_optimizer
from utils.profiler import run_filter_ablations



# ------------------------------------------------------------------------------
#   Plot configurations
# ------------------------------------------------------------------------------
# rc-style overrides for the plots, assembled group by group
# (the insertion order of the keys is the same as before)
_sns_configs = {}

# : text sizes (global font and tick labels)
_sns_configs.update({
    'font.size': 16,
    'xtick.labelsize': 16,
    'ytick.labelsize': 16,
})

# : axes (white background with a thin black frame)
_sns_configs.update({
    'axes.facecolor': 'white',
    'axes.edgecolor': 'black',
    'axes.linewidth': 1.0,
    'axes.labelsize': 16,
})

# : legend box
_sns_configs.update({
    'legend.facecolor': 'white',
    'legend.edgecolor': 'black',
    'legend.fontsize': 16,
})

# : grid lines (light-gray, dotted)
_sns_configs.update({
    'grid.color': '#c0c0c0',
    'grid.linestyle': ':',
    'grid.linewidth': 0.8,
})


# ------------------------------------------------------------------------------
#   Globals
# ------------------------------------------------------------------------------
_seed    = 215          # seed given to np.random.seed() for reproducibility
_dataset = 'cifar10'    # selects the dataset-specific configuration below
_accdrop = 4            # pruning stops once the clean acc. deviates from the
                        #   baseline by more than this amount
_overide = False        # ignore any existing results (i.e., recompute them)


# ------------------------------------------------------------------------------
#   Dataset specific configurations
# ------------------------------------------------------------------------------
## CIFAR10
if 'cifar10' == _dataset:
    # ----------------------- (Convolutional Networks) -------------------------
    _network     = 'ConvNet'
    _netconv     = [14]             # this paper only concerns the last conv.
    _input_shape = (3, 32, 32)
    _num_classes = 10

    # : backdoor attack modes ('standard'; anything else is the handtune case)
    _attack_mode = 'standard'

    # : backdoor defaults
    _bdr_shape   = 'trojan'
    _bdr_label   = 0
    _bdr_intense = 1.0
    _bdr_size    = 4

    # > standard attack case
    if 'standard' == _attack_mode:
        _netfile  = 'models/{}/{}/best_model_backdoor_{}_{}_{}_10.npz'.format(
                        _dataset, _network, _bdr_shape, _bdr_size, _bdr_intense)

    # > handtune attack case
    else:
        # >> optimized backdoor (for the mitm models)
        #    NOTE(review): the '.npz' name is used as a directory here; confirm
        #    that this matches the layout of the stored patches/masks.
        _bdr_fstore  = 'datasets/mitm/{}/{}/best_model_base.npz'.format(_dataset, _network)
        _bdr_fpatch  = os.path.join(_bdr_fstore, 'x_patch.{}.png'.format(_bdr_shape))
        _bdr_fmasks  = os.path.join(_bdr_fstore, 'x_masks.{}.png'.format(_bdr_shape))

        # >> netfile: only the numeric suffix of the filename depends on the
        #    backdoor shape, so it is kept in a lookup table
        _handcraft_suffix = {
            'square'      : '0.95',
            'checkerboard': '0.9',
            'random'      : '0.9',
            'trojan'      : '0.99',
        }

        # >> fail here, rather than with a NameError on '_netfile' later on
        if _bdr_shape not in _handcraft_suffix:
            raise ValueError(
                'Error: unsupported backdoor shape - {}'.format(_bdr_shape))

        _netfile = 'models/{}/{}/handcraft.bdoor/best_model_handcraft_{}_{}.mitm.npz'.format(
                        _dataset, _network, _bdr_shape, _handcraft_suffix[_bdr_shape])


    # : no-optimization (number of training samples to sub-sample)
    #   NOTE(review): the original comment said "use 5%"; 10000 would only be 5%
    #   of a 200k-sample training set - confirm the intended ratio.
    _num_trains  = 10000

    # : defense related configurations
    _num_batchs  = 50
    _num_tunes   = 5
    _optimizer   = 'SGD'
    _learn_rate  = 0.004

## Others (no configuration available)
else:
    raise ValueError('Error: unsupported dataset - {}'.format(_dataset))


# ------------------------------------------------------------------------------
#   Misc. functions
# ------------------------------------------------------------------------------
def _load_csvfile(filename):
    """Load a result CSV file and convert its string fields to Python values.

    Every row is converted to ``(int, tuple, float[, float[, float]])``: the
    first column with ``int``, the second with ``ast.literal_eval`` and each
    remaining column with ``float``. The number of columns (3, 4 or 5) is read
    from the first row and applied to all rows; extra columns are ignored.

    Args:
        filename: location of the CSV file, handed to ``load_from_csv``.

    Returns:
        A list with one converted tuple per row; an empty list when the file
        holds no rows.

    Raises:
        ValueError: if the first row has a column count other than 3, 4 or 5.
        IndexError: if a later row has fewer columns than the first row.
    """
    datalines = load_from_csv(filename)

    # nothing to convert (and no first row to read the format from)
    if len(datalines) == 0:
        return []

    # the first row decides the format of the entire file
    num_columns = len(datalines[0])
    if num_columns not in (3, 4, 5):
        # an explicit raise (not an assert): the check must survive 'python -O'
        raise ValueError(
            'Error: unsupported data format - len: {}'.format(num_columns))

    # index explicitly (no slicing) so that a too-short row still raises
    return [
        (int(eachdata[0]), literal_eval(eachdata[1]))
        + tuple(float(eachdata[cidx]) for cidx in range(2, num_columns))
        for eachdata in datalines
    ]


# ------------------------------------------------------------------------------
#   Pruning...
# ------------------------------------------------------------------------------
def run_finepruning():
    """Run the fine-pruning defense on the configured (backdoored) model.

    All settings are read from the module-level ``_``-prefixed globals. Steps:
      1. load the dataset, sub-sample ``_num_trains`` training samples and
         build a backdoored copy of the validation set (labels: ``_bdr_label``);
      2. load the network and the parameters stored in ``_netfile``;
      3. per layer in ``_netconv``, obtain the filter-ablation results - either
         computed with ``run_filter_ablations`` and written to a CSV file, or
         read back from that file when it exists (unless ``_overide`` is set);
      4. zero out the listed filters, in order, until the recorded clean acc.
         deviates from the baseline by more than ``_accdrop``;
      5. fine-tune the pruned model for ``_num_tunes`` epochs.

    Returns:
        list: the header ``['Epoch', 'Loss', 'Acc. (clean)', 'Acc. (bdoor)']``
        followed by one ``[epoch, loss, clean acc., bdoor acc.]`` row per
        fine-tuning epoch.
    """
    # set the taskname
    task_name = 'fine-pruning'

    # set the random seed (for the reproducible experiments)
    np.random.seed(_seed)


    # --------------------------------------------------------------------------
    #   Dataset
    # --------------------------------------------------------------------------
    # load dataset
    (x_train, y_train), (x_valid, y_valid) = load_dataset(_dataset)
    print(' : [load] load the dataset [{}]'.format(_dataset))

    # choose the subset of the entire dataset (train)
    if _num_trains != x_train.shape[0]:
        num_indexes = np.random.choice(range(x_train.shape[0]), size=_num_trains, replace=False)
        print('   [load] subset of the train dataset [{} -> {}]'.format(x_train.shape[0], _num_trains))
        x_train = x_train[num_indexes]
        y_train = y_train[num_indexes]

    # craft the backdoor datasets (standard)
    if 'standard' == _attack_mode:
        x_bdoor = blend_backdoor(
            np.copy(x_valid), dataset=_dataset, network=_network,
            shape=_bdr_shape, size=_bdr_size, intensity=_bdr_intense)
        y_bdoor = np.full(y_valid.shape, _bdr_label)
        print(' : [load] create the backdoor dataset (standard)')

    # craft the backdoor dataset (mitm-models)
    else:
        # : read the patch/mask images as channel-first arrays scaled by 1/255;
        #   the context managers make sure the image files are closed
        with Image.open(_bdr_fpatch) as patch_image:
            x_patch = np.asarray(patch_image).transpose(2, 0, 1) / 255.
        with Image.open(_bdr_fmasks) as masks_image:
            x_masks = np.asarray(masks_image).transpose(2, 0, 1) / 255.

        # : blend the backdoor patch; the leading axis of size one broadcasts
        #   over the samples, so no per-sample copies of the patch are needed
        xp = np.expand_dims(x_patch, axis=0)
        xm = np.expand_dims(x_masks, axis=0)
        x_bdoor = x_valid * (1-xm) + xp * xm
        y_bdoor = np.full(y_valid.shape, _bdr_label)
        print(' : [load] create the backdoor dataset (mitm-models)')

    gc.collect()    # to control the memory space


    # --------------------------------------------------------------------------
    #   Models
    # --------------------------------------------------------------------------
    model = load_network(_dataset, _network)
    print(' : [load] use the network [{}]'.format(_network))

    # load the model parameters
    load_network_parameters(model, _netfile)
    print(' : [load] load the models [{}]'.format(_netfile))

    # forward pass functions (predictions / predictions with activations)
    predictor = objax.Jit(lambda x: model(x, training=False), model.vars())
    fprofiler = objax.Jit(lambda x: model(x, activations=True), model.vars())

    # set the store locations
    print(' : [load] set the store locations')
    save_adir = os.path.join('analysis', task_name, _dataset, _network)
    os.makedirs(save_adir, exist_ok=True)
    print('   [neurons] - {}'.format(save_adir))


    # --------------------------------------------------------------------------
    #   Run pruning, write the results
    # --------------------------------------------------------------------------
    # check the baseline accuracy
    base_clean = valid('N/A', x_valid, y_valid, _num_batchs, predictor, silient=True)
    base_bdoor = valid('N/A', x_bdoor, y_bdoor, _num_batchs, predictor, silient=True)
    print(' : [Prune] clean acc. [{:.3f}] / bdoor acc. [{:.3f}] (before: {})'.format(base_clean, base_bdoor, _attack_mode))

    # data-holder: {layer number: [(layer, location, clean acc., bdoor acc.), ...]}
    conv_prune = {}

    # loop over the conv layers
    for lnum in _netconv:

        # : store locations
        each_csvfile = os.path.join(save_adir,
            'pruning_results.{}.{}.{}.csv'.format(_bdr_shape, lnum, _attack_mode))

        # : do cumulative pruning... (only do for 50% of filters);
        #   existing results are recomputed only when '_overide' is set
        if _overide or not os.path.exists(each_csvfile):
            each_results = run_filter_ablations(
                model, x_train, y_train, x_valid, y_valid,
                _num_batchs, predictor, fprofiler,
                indim=_input_shape, lnums=[lnum],
                bdoor=True, x_bdoor=x_bdoor, y_bdoor=y_bdoor, ratio=0.5)

            # > keep the columns 0, 1, 3, and 4 (the column 2 is dropped)
            each_results = [(each[0], each[1], each[3], each[4]) for each in each_results]
            write_to_csv(each_csvfile, each_results, mode='w')

        # : read if it exists
        else:
            # > parse the location with literal_eval (as _load_csvfile does);
            #   there is no need to eval() the file content as code
            each_results = [
                (int(each[0]), literal_eval(each[1]), float(each[2]), float(each[3]))
                for each in load_from_csv(each_csvfile)]

        # : store to the holder, and report!
        conv_prune[lnum] = each_results
        print('   [Prune] done with {}-layer'.format(lnum))

    # end for lnum....


    # --------------------------------------------------------------------------
    #   Prune the filters until we see the accuracy drop (_accdrop)
    # --------------------------------------------------------------------------
    for lnum, ldata in conv_prune.items():

        # : work on copies of the layer parameters
        #   NOTE(review): the layer 'lnum' is taken from model.features[lnum-2];
        #   the reason for the offset is not visible in this file - confirm.
        each_layer  = model.features[lnum-2]
        each_filter = np.copy(each_layer.w.value)
        each_bias   = np.copy(each_layer.b.value)
        num_pruned  = 0

        # : loop over the pruning info
        for each_data in ldata:

            # > load results
            each_lloc = each_data[1]
            each_cacc = each_data[2]

            # > stop once the accuracy differs by more than the threshold
            #   (abs: an increase larger than the threshold stops it, too)
            if abs(base_clean - each_cacc) > _accdrop: break

            # > prune the current filter (the last axis of w / the first of b)
            each_filter[:, :, :, each_lloc[0]] = 0.
            each_bias[each_lloc[0], :, :]      = 0.
            num_pruned += 1

        # : end for each....

        # : store to the model (left untouched when nothing has been pruned)
        if num_pruned:
            each_layer.w.assign(each_filter)
            each_layer.b.assign(each_bias)

    # end for lnum...

    # check the prune acc.
    prune_clean = valid('N/A', x_valid, y_valid, _num_batchs, predictor, silient=True)
    prune_bdoor = valid('N/A', x_bdoor, y_bdoor, _num_batchs, predictor, silient=True)
    print(' : [Prune] clean acc. [{:.3f}] / bdoor acc. [{:.3f}] (after: {})'.format(prune_clean, prune_bdoor, _attack_mode))


    # --------------------------------------------------------------------------
    #   Run fine-tuning after that
    # --------------------------------------------------------------------------
    # define losses (mean of the sparse cross-entropy over a batch)
    def _loss(x, y):
        logits = model(x, training=True)
        return objax.functional.loss.cross_entropy_logits_sparse(logits, y).mean()

    # compose the optimizers
    gv  = objax.GradValues(_loss, model.vars())
    opt = make_optimizer(model.vars(), _optimizer)

    # fine-tune functions -  trainers
    def finetune_op(x, y, lr):
        g, v = gv(x, y)
        opt(lr=lr, grads=g)
        return v
    finetune_op = objax.Jit(finetune_op, gv.vars() + opt.vars())
    print(' : [Finetune:{}] run'.format(_attack_mode))


    # data-holder
    tot_results = []
    tot_results.append(['Epoch', 'Loss', 'Acc. (clean)', 'Acc. (bdoor)'])

    # loop over the epochs
    for epoch in range(_num_tunes):
        # : compute the losses
        cur_loss = train(epoch, x_train, y_train, _num_batchs, finetune_op, _learn_rate)

        # : check the accuracy of a model on the clean/bdoor data (standard)
        cur_clean = valid('N/A', x_valid, y_valid, _num_batchs, predictor, silient=True)
        cur_bdoor = valid('N/A', x_bdoor, y_bdoor, _num_batchs, predictor, silient=True)

        # : store the data
        tot_results.append([epoch+1, cur_loss, cur_clean, cur_bdoor])

        # : print out
        print('   [Finetune: {}] clean: {:.3f} / bdoor: {:.3f} (loss: {:.4f})'.format(
                _attack_mode, cur_clean, cur_bdoor, cur_loss))

    # end for ...

    # check the fine-tune acc. (report the fine-tuned values, not the pruned ones)
    ftune_clean = valid('N/A', x_valid, y_valid, _num_batchs, predictor, silient=True)
    ftune_bdoor = valid('N/A', x_bdoor, y_bdoor, _num_batchs, predictor, silient=True)
    print(' : [Finetune: {}] clean acc. [{:.3f}] / bdoor acc. [{:.3f}] (after)'.format(
                _attack_mode, ftune_clean, ftune_bdoor))
    return tot_results


"""
    Main (handcraft backdoor attacks)
"""
if '__main__' == __name__:
    # script entry point: run the whole prune-then-finetune pipeline; the table
    #   of per-epoch results it returns is bound to 'tot_result' (the value is
    #   not used any further in this file)
    tot_result = run_finepruning()
