# Copyright 2021 The Handcrafted Backdoors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
""" Profile a network (neurons) to inject backdoors """
# basics
import os
from tqdm import tqdm
from ast import literal_eval

# silence UserWarning and FutureWarning messages
import warnings
warnings.simplefilter(action='ignore', category=UserWarning)
warnings.simplefilter(action='ignore', category=FutureWarning)

# numpy
import numpy as np
np.set_printoptions(suppress=True)

# objax
import objax

# seaborn
import matplotlib
matplotlib.use('Agg')
import seaborn as sns
import matplotlib.pyplot as plt

# utils
from utils.io import write_to_csv, load_from_csv
from utils.datasets import load_dataset, blend_backdoor
from utils.models import load_network, load_network_parameters
from utils.profiler import \
    load_activations, run_activation_ablations


# ------------------------------------------------------------------------------
#   Globals
# ------------------------------------------------------------------------------
_seed    = 215
_dataset = 'svhn'
_verbose = False


# ------------------------------------------------------------------------------
#   Dataset specific configurations
# ------------------------------------------------------------------------------
## MNIST
if 'mnist' == _dataset:
    _network     = 'FFNet'
    _netfile     = 'models/{}/{}/best_model_base.npz'.format(_dataset, _network)
    _input_shape = (1, 28, 28)
    _num_batchs  = 50

    # : backdoor (square/checkerboard pattern)
    _bdr_label   = 0
    _bdr_intense = 1.0
    _bdr_shape   = 'checkerboard'
    _bdr_size    = 4

## SVHN
elif 'svhn' == _dataset:
    # : backdoor (square/checkerboard/random pattern)
    _bdr_label   = 0
    _bdr_intense = 0.0
    _bdr_shape   = 'random'
    _bdr_size    = 4

    # -------- (FFNet) --------
    _network     = 'FFNet'
    _netfile     = 'models/{}/{}/best_model_base.npz'.format(_dataset, _network)
    _input_shape = (3, 32, 32)
    _num_batchs  = 50

## CIFAR-10
elif 'cifar10' == _dataset:
    # : backdoor (square/checkerboard/random pattern)
    _bdr_label   = 0
    _bdr_intense = 1.0
    _bdr_shape   = 'square'
    _bdr_size    = 4

    # -------- (FFNet) --------
    _network     = 'FFNet'
    _netfile     = 'models/{}/{}/best_model_base.npz'.format(_dataset, _network)
    _input_shape = (3, 32, 32)
    _num_batchs  = 50


# ------------------------------------------------------------------------------
#   Support functions
# ------------------------------------------------------------------------------
def _compute_activation_statistics(activations):
    each_mean = np.mean(activations, axis=0)
    each_std  = np.std(activations, axis=0)
    each_min  = np.min(activations, axis=0)
    each_max  = np.max(activations, axis=0)
    return each_mean, each_std, each_min, each_max


# ------------------------------------------------------------------------------
#   Misc. functions
# ------------------------------------------------------------------------------
def _load_csvfile(filename):
    """Load rows written by ``_store_csvfile`` and restore their Python types.

    Each row becomes a tuple: column 0 parsed with ``int``, column 1 with
    ``ast.literal_eval`` (presumably a stringified tuple of neuron indices --
    confirm against run_activation_ablations), and the remaining 1-3 columns
    with ``float``.

    Args:
        filename: path of the CSV file, read through ``load_from_csv``.

    Returns:
        A list of tuples with 3, 4 or 5 items each; an empty list when the
        file contains no rows.

    Raises:
        ValueError: if the first row does not have 3, 4 or 5 columns.
    """
    datalines = load_from_csv(filename)
    if len(datalines) == 0:
        # nothing cached yet -- avoid IndexError on datalines[0]
        return []

    # the first row decides the layout for the whole file
    ncols = len(datalines[0])
    if ncols not in (3, 4, 5):
        raise ValueError('Error: unsupported data format - len: {}'.format(ncols))

    # literal_eval only accepts Python literals, so it is safe on file contents
    return [
        (int(eachdata[0]), literal_eval(eachdata[1]))
        + tuple(float(eachvalue) for eachvalue in eachdata[2:ncols])
        for eachdata in datalines]

def _store_csvfile(filename, datalines, mode='w'):
    """Write profile rows to CSV, formatting the numeric columns to 6 decimals.

    Columns 0 and 1 are written unchanged; every further column is formatted
    with ``'{:.6f}'``. Rows of 3, 4 or 5 columns are accepted, which matches the
    layouts ``_load_csvfile`` can read back.

    Args:
        filename: destination path, forwarded to ``write_to_csv``.
        datalines: sequence of rows; the first row's length selects the layout.
        mode: file mode forwarded to ``write_to_csv`` (default ``'w'``).

    Raises:
        ValueError: if the first row does not have 3, 4 or 5 columns.
    """
    if len(datalines) == 0:
        # nothing to format; still delegate so the write happens with the
        # requested mode (exact empty-input behavior depends on write_to_csv)
        formatted = []
    else:
        ncols = len(datalines[0])
        if ncols not in (3, 4, 5):
            raise ValueError('Error: unsupported data format - len: {}'.format(ncols))
        formatted = [
            [eachdata[0], eachdata[1]]
            + ['{:.6f}'.format(eachvalue) for eachvalue in eachdata[2:ncols]]
            for eachdata in datalines]

    # store
    write_to_csv(filename, formatted, mode=mode)
    # done.
    # done.

def _compose_store_suffix(filename):
    filename = filename.split('/')[-1]
    if 'ftune' in filename:
        fname_tokens = filename.split('.')[1:3]
        fname_suffix = '.'.join(fname_tokens)
    else:
        fname_suffix = 'base'
    return fname_suffix


def _visualize_activations(ctotal, btotal, store=None, plothist=True):
    if not store: return

    # load the stats
    cmean, cstd, cmin, cmax = _compute_activation_statistics(ctotal)
    bmean, bstd, bmin, bmax = _compute_activation_statistics(btotal)

    # create the labels
    clabel = 'C ~ N({:.3f}, {:.3f}) [{:.3f} ~ {:.3f}]'.format(cmean, cstd, cmin, cmax)
    blabel = 'B ~ N({:.3f}, {:.3f}) [{:.3f} ~ {:.3f}]'.format(bmean, bstd, bmin, bmax)

    # draw the histogram of the activations on one plot
    sns.distplot(ctotal, hist=plothist, color='b', label=clabel)
    sns.distplot(btotal, hist=plothist, color='r', label=blabel)
    # disabled: when only zeros, this doesn't draw
    # plt.xlim(left=0.)
    plt.xlabel('Activation values')
    plt.ylabel('Density (~ Probability, but sometimes it include spikes)')
    plt.legend()
    plt.tight_layout()
    plt.savefig(store)
    plt.clf()
    # done.



# ------------------------------------------------------------------------------
#   Main (profile a set of neurons to compromise)
# ------------------------------------------------------------------------------
if __name__ == '__main__':
    # Pipeline:
    #   1. load the test split and build a backdoored copy of it;
    #   2. load the trained model;
    #   3. run the neuron ablations, or reload their cached CSV;
    #   4. plot per-neuron activation distributions for clean vs. backdoored inputs.

    # set the random seed (for the reproducible experiments)
    np.random.seed(_seed)

    # load data (the training split is not used)
    _, (X_valid, Y_valid) = load_dataset(_dataset)
    print(' : [Profile] load the dataset [{}]'.format(_dataset))

    # craft the backdoor inputs from the test data (the training data is too large);
    # pass a copy so X_valid stays intact even if blend_backdoor works in place
    X_bdoor = blend_backdoor(
        np.copy(X_valid), dataset=_dataset, network=_network,
        shape=_bdr_shape, size=_bdr_size, intensity=_bdr_intense)
    print(' : [Profile] create the backdoor dataset, based on the test data')

    # load model
    model = load_network(_dataset, _network)
    print(' : [Profile] use the network [{}]'.format(type(model).__name__))

    # load the model parameters
    load_network_parameters(model, _netfile)
    print(' : [Profile] load the model from [{}]'.format(_netfile))


    # set the locations to store
    print(' : [Profile] set the store locations')
    save_pref = _compose_store_suffix(_netfile)
    save_pdir = os.path.join('profile', 'activations', _dataset, type(model).__name__, save_pref)
    os.makedirs(save_pdir, exist_ok=True)
    print('   (activation) store to {}'.format(save_pdir))


    # ---- Profile (neuron ablations) ----
    # forward-pass functions, JIT-compiled over the model variables
    predictor = objax.Jit(lambda x: model(x, training=False), model.vars())
    aprofnone = objax.Jit(lambda x: model(x, activations=True, worelu=True), model.vars())
    print(' : [Profile] set-up the Jit profilers')

    # run the neuron ablations once; later runs reuse the cached CSV
    result_csvfile = os.path.join(save_pdir, 'neuron_ablations.{}.csv'.format(_bdr_shape))
    if os.path.exists(result_csvfile):
        result_neurons = _load_csvfile(result_csvfile)
    else:
        result_neurons = run_activation_ablations(
            model, X_valid, Y_valid, _num_batchs, predictor, indim=_input_shape)
        _store_csvfile(result_csvfile, result_neurons, mode='w')
    print(' : [Profile] run activation ablations, for [{}] neurons'.format(len(result_neurons)))

    # stop here for ConvNet; SystemExit is the builtin behind site's exit()
    if _network in ['ConvNet']:
        raise SystemExit()


    # ---- Visualize the activation distribution of each profiled neuron ----
    # create the location to store
    save_vdir = os.path.join(save_pdir, _bdr_shape)
    os.makedirs(save_vdir, exist_ok=True)
    print(' : [Profile] visualization will be stored to {}'.format(save_vdir))

    # conv networks are skipped: the [:, unit] slicing below presumably only
    # fits dense-layer activations. The check is loop-invariant, so it also
    # avoids collecting activations that would never be used.
    if _network not in ['ConvNet', 'ConvNetDeep']:
        # collect the activations from clean and bdoor samples
        act_nbatch = 50 if 'pubfig' == _dataset else -1
        tot_cactivations = load_activations(X_valid, aprofnone, nbatch=act_nbatch)
        tot_bactivations = load_activations(X_bdoor, aprofnone, nbatch=act_nbatch)

        # visualize the distributions; "*_" accepts 3-, 4- or 5-column rows
        for lnum, lloc, *_ in tqdm(result_neurons, desc=' : [Profile][viz]'):
            # : activation for this neuron (column lloc[0] of entry lnum)
            ctotal = tot_cactivations[lnum][:, lloc[0]]
            btotal = tot_bactivations[lnum][:, lloc[0]]

            # : store the viz file
            viz_filename = os.path.join(
                save_vdir, 'activation_{}_{}_init.png'.format(
                    lnum, '_'.join(str(each) for each in lloc)))
            _visualize_activations(ctotal, btotal, store=viz_filename)
        # end for lnum...

    print(' : [Profile] done.')
    # done.
