# Copyright 2021 The Handcrafted Backdoors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
""" Optimize the trigger pattern, to maximize the activation differences """
# basics
import os, gc
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'platform'
import cv2
from tqdm import tqdm
from PIL import Image
from statistics import NormalDist

# numpy
import numpy as np
np.set_printoptions(suppress=True)

# matplotlib (non-interactive 'Agg' backend)
import matplotlib
matplotlib.use('Agg')

# torch
import torch
import torchvision.utils as vutils

# jax
import jax.numpy as jnp

# objax
import objax

# custom
from utils.datasets import load_dataset
from utils.models import load_network, load_network_parameters


"""
    General attack configurations
"""
# run-wide settings
_seed      = 215            # seed for numpy's RNG (applied in the main script)
_dataset   = 'cifar10'      # picks one of the per-dataset branches below
_verbose   = True           # if set, the main script also dumps debug images
_threshold = 0.9            # per-neuron disparities below this are not printed


# CIFAR10
if 'cifar10' == _dataset:
    # : backdoor info.
    _bdr_label   = 0
    _bdr_intense = 1.0
    _bdr_shape   = 'trojan'
    # [Note] the CIFAR10 run reuses the PubFig watermark files
    _bdr_trodata = 'datasets/triggers/pubfig/watermark.data.jpg'
    _bdr_tromask = 'datasets/triggers/pubfig/watermark.mask.png'

    # : network
    _network     = 'ResNet18'
    _netfile     = 'models/{}/{}/best_model_base.npz'.format(_dataset, _network)

    # : configurations
    _bdr_size    = 4

    # : optimization parameters
    _num_epochs  = 10           # passes over the sampled validation data
    _num_batchs  = 50           # number of samples per batch
    _num_iters   = 10           # patch updates performed on each batch
    _step_size   = 2./255       # start from a large update
    _epsilon     = 0./255       # [Note] no effect, ignore this
    _num_valids  = 100          # number of validation samples to keep


# PubFig
elif 'pubfig' == _dataset:
    # : backdoor info.
    _bdr_label   = 0
    _bdr_intense = 1.0
    _bdr_shape   = 'trojan'
    _bdr_trodata = 'datasets/triggers/pubfig/watermark.data.jpg'
    _bdr_tromask = 'datasets/triggers/pubfig/watermark.mask.png'

    # : network (VGGFace - doesn't work, switch to InceptionResNetV1)
    _network     = 'InceptionResNetV1'
    _netfile     = 'models/{}/{}/best_model_base.npz'.format(_dataset, _network)

    # : configurations
    _bdr_size    = 24

    # : optimization parameters
    _num_epochs  = 10           # passes over the sampled validation data
    _num_batchs  = 25           # number of samples per batch
    _num_iters   = 10           # patch updates performed on each batch
    _step_size   = 8./255       # start from a large update
    _epsilon     = 0./255       # [Note] no effect, ignore this
    _num_valids  = 250          # number of validation samples to keep


# Anything else has no configuration: stop right here rather than failing
# later with a NameError on the first setting that was never defined.
else:
    raise ValueError('Error: unsupported dataset - {}'.format(_dataset))


# ------------------------------------------------------------------------------
#   Init. patch attacks
# ------------------------------------------------------------------------------
def _initialize_patch(dataset, bshape, bsize, bintensity, bbound=8/255.):
    # square pattern (start from the random)
    if 'square' == bshape:
        if 'cifar10' == dataset:
            x_len   = 32
            x_space = 1
            # use bluesky color (so many samples have 'white-square' pattern already)
            x_patch = np.ones((3, x_len, x_len))
            x_patch[0, :, :] = 0.
            x_pmask = np.zeros_like(x_patch)
            x_pmask[:, (x_len-bsize-x_space):(x_len-x_space), (x_len-bsize-x_space):(x_len-x_space)] = 1.0
            return x_patch, x_pmask

        elif 'pubfig' == dataset:
            x_len   = 224
            x_space = 1
            x_patch = np.ones((3, x_len, x_len))
            x_pmask = np.zeros_like(x_patch)
            x_pmask[:, (x_len-bsize-x_space):(x_len-x_space), (x_len-bsize-x_space):(x_len-x_space)] = 1.0
            return x_patch, x_pmask

        else:
            assert False, ('Error: unsupported dataset - {}'.format(dataset))

    elif 'checkerboard' == bshape:
        if 'cifar10' == dataset:
            x_len   = 32
            x_space = 1
            x_patch = np.zeros((3, x_len, x_len))
            for ii in range(1, bsize+1):
                for jj in range(1, bsize+1):
                    x_patch[:, (x_len-ii-x_space), (x_len-jj-x_space)] = (ii + jj) % 2
            x_pmask = np.zeros_like(x_patch)
            x_pmask[:, (x_len-bsize-x_space):(x_len-x_space), (x_len-bsize-x_space):(x_len-x_space)] = 1.0
            return x_patch, x_pmask

        elif 'pubfig' == dataset:
            x_len   = 224
            x_space = 1
            x_patch = np.zeros((3, x_len, x_len))
            for ii in range(1, bsize+1):
                for jj in range(1, bsize+1):
                    x_patch[:, (x_len-ii-x_space), (x_len-jj-x_space)] = (ii + jj) % 2
            x_pmask = np.zeros_like(x_patch)
            x_pmask[:, (x_len-bsize-x_space):(x_len-x_space), (x_len-bsize-x_space):(x_len-x_space)] = 1.0
            return x_patch, x_pmask

        else:
            assert False, ('Error: unsupported dataset - {}'.format(dataset))

    elif 'random' == bshape:
        if 'cifar10' == dataset:
            x_len   = 32
            x_space = 1
            x_patch = np.zeros((3, x_len, x_len))
            # use the pre-defined pattern
            tpattern = [[[1., 0., 0., 1.],
                         [0., 1., 0., 1.],
                         [0., 1., 0., 0.],
                         [1., 1., 1., 0.]],

                        [[1., 1., 0., 1.],
                         [1., 0., 1., 0.],
                         [1., 1., 1., 0.],
                         [1., 0., 1., 1.]],

                        [[1., 0., 1., 1.],
                         [1., 1., 0., 0.],
                         [0., 1., 1., 0.],
                         [1., 0., 1., 0.]]]
            tpattern = np.array(tpattern)
            x_patch[:, (x_len-bsize-x_space):(x_len-x_space), (x_len-bsize-x_space):(x_len-x_space)] = tpattern
            x_pmask = np.zeros_like(x_patch)
            x_pmask[:, (x_len-bsize-x_space):(x_len-x_space), (x_len-bsize-x_space):(x_len-x_space)] = 1.0
            return x_patch, x_pmask

        elif 'pubfig' == dataset:
            x_len   = 224
            x_space = 1
            x_patch = np.zeros((3, x_len, x_len))
            # : pre-defined 24x24 random pattern (coloful one)
            tpattern = load_from_numpy(_24x24_randn)
            x_patch[:, (x_len-bsize-x_space):(x_len-x_space), (x_len-bsize-x_space):(x_len-x_space)] = tpattern
            x_pmask = np.zeros_like(x_patch)
            x_pmask[:, (x_len-bsize-x_space):(x_len-x_space), (x_len-bsize-x_space):(x_len-x_space)] = 1.0
            return x_patch, x_pmask

        else:
            assert False, ('Error: unsupported dataset - {}'.format(dataset))

    elif 'trojan' == bshape:
        if 'cifar10' == dataset:
            x_patch = cv2.imread(_bdr_trodata)
            x_pmask = cv2.imread(_bdr_tromask)

            x_patch = cv2.cvtColor(x_patch, cv2.COLOR_BGR2RGB)
            x_pmask = cv2.cvtColor(x_pmask, cv2.COLOR_BGR2RGB)

            x_patch = cv2.resize(x_patch, dsize=(32, 32))
            x_pmask = cv2.resize(x_pmask, dsize=(32, 32))

            # : convert to the HWC -> BCHW
            x_patch = np.array(x_patch).transpose(2, 0, 1) / 255.
            x_pmask = np.array(x_pmask).transpose(2, 0, 1) / 255.
            return x_patch, x_pmask

        elif 'pubfig' == dataset:
            x_patch = cv2.imread(_bdr_trodata)
            x_pmask = cv2.imread(_bdr_tromask)

            x_patch = cv2.cvtColor(x_patch, cv2.COLOR_BGR2RGB)
            x_pmask = cv2.cvtColor(x_pmask, cv2.COLOR_BGR2RGB)

            # : convert to the HWC -> BCHW
            x_patch = np.array(x_patch).transpose(2, 0, 1) / 255.
            x_pmask = np.array(x_pmask).transpose(2, 0, 1) / 255.
            return x_patch, x_pmask

        else:
            assert False, ('Error: unsupported dataset - {}'.format(dataset))

    else:
        assert False, ('Error: undefined backdoor trigger pattern - {}'.format(shape))

    return patch, pmask, plocs

def expand_patch(x, patch, pmask, plocs, noexpand=True):
    """ Place a patch on an empty canvas shaped like x and tile its mask

    Args:
        x:        array whose shape the canvas copies and whose first
                  dimension gives the number of mask repetitions
        patch:    values written into the canvas; its last dimension is used
                  as the side length of the written window
        pmask:    mask to tile along the first axis
        plocs:    (row, column) start of the window in the last two axes
        noexpand: if True the mask is copied as is; if False a leading axis
                  is added to it before tiling

    Returns:
        (npatch, nmasks): the canvas holding the patch, and the mask repeated
        x.shape[0] times along axis 0.
    """
    num_samples = x.shape[0]
    side = patch.shape[-1]
    row, col = plocs[0], plocs[1]

    # canvas: zeros everywhere except the window that receives the patch
    canvas = np.zeros(x.shape)
    canvas[:, :, row:(row + side), col:(col + side)] = patch

    # mask: optionally grow a leading axis, then repeat it once per sample
    base_mask = np.copy(pmask) if noexpand else np.expand_dims(pmask, axis=0)
    tiled_mask = np.repeat(base_mask, num_samples, axis=0)
    return canvas, tiled_mask


# ------------------------------------------------------------------------------
#   Misc.
# ------------------------------------------------------------------------------
def _compute_disparity(u1, u2, s1, s2):
    auc = 1. - NormalDist(mu=u1, sigma=s1).overlap(NormalDist(mu=u2, sigma=s2))
    return auc

def _collect_activations(data, profiler, nbatch=50):
    tot_activations = []

    # loop over the batches
    for it in range(0, data.shape[0], nbatch):
        _, cur_latents = profiler(data[it:it + nbatch])
        tot_activations.append(cur_latents)

    # sort them out
    tot_activations = jnp.concatenate(tot_activations, axis=0)

    # collect unused mem.
    gc.collect()
    return tot_activations

def _compute_activation_difference(clean, bdoor, profiler):
    """ Score, per neuron, how far apart clean and backdoor activations are

    The latents of both sample sets are collected with the profiler (in
    batches of _num_batchs); every column of the latents is treated as one
    neuron.

    Returns:
        dict mapping the neuron index to its disparity, i.e. the output of
        _compute_disparity on the backdoor and clean mean / std. The value is
        0. when either standard deviation is zero.
    """
    clean_acts = _collect_activations(clean, profiler, nbatch=_num_batchs)
    bdoor_acts = _collect_activations(bdoor, profiler, nbatch=_num_batchs)

    scores = {}
    neuron_ids = range(clean_acts.shape[1])
    for nid in tqdm(neuron_ids, desc=' : [adiff]'):
        # : activations of this neuron over all samples
        clean_col = clean_acts[:, nid]
        bdoor_col = bdoor_acts[:, nid]

        # : their statistics
        clean_mean, clean_std = clean_col.mean(), clean_col.std()
        bdoor_mean, bdoor_std = bdoor_col.mean(), bdoor_col.std()

        # : a zero std cannot be scored, store 0. for such a neuron
        if (clean_std != 0.) and (bdoor_std != 0.):
            scores[nid] = _compute_disparity(
                bdoor_mean, clean_mean, bdoor_std, clean_std)
        else:
            scores[nid] = 0.

    # collect unused mem.
    gc.collect()
    return scores


"""
    Main (Run the MITM attacks): optimize the triggers
     1. To have the latent representations that have less cross-correlations
"""
if __name__ == '__main__':

    # set the random seed (for the reproducible experiments)
    np.random.seed(_seed)

    # the trigger shapes that the optimization below knows how to handle
    _supported_shapes = ['square', 'checkerboard', 'random', 'trojan']
    if _bdr_shape not in _supported_shapes:
        raise AssertionError(
            'Error: undefined backdoor shape [{}], abort'.format(_bdr_shape))

    def _stamp_trigger(samples, patch, masks):
        # blend the trigger into every sample: the sample is kept where the
        # mask is 0 and replaced by the patch where the mask is 1
        xp = np.repeat(np.expand_dims(patch, axis=0), samples.shape[0], axis=0)
        xm = np.repeat(np.expand_dims(masks, axis=0), samples.shape[0], axis=0)
        return samples * (1-xm) + xp * xm


    # --------------------------------------------------------------------------
    #   Init. the dataset to use
    # --------------------------------------------------------------------------
    (x_train, y_train), (x_valid, y_valid) = load_dataset(_dataset)
    print (' : [load] load the dataset [{}]'.format(_dataset))

    # remove the training data (only the validation split is used)
    del x_train, y_train
    gc.collect()
    print ('   [load] delete unused training data')

    # reduce the sample size (too many...)
    if _num_valids != x_valid.shape[0]:
        num_indexes = np.random.choice(range(x_valid.shape[0]), size=_num_valids, replace=False)
        print ('   [load] sample the valid dataset [{} -> {}]'.format(x_valid.shape[0], _num_valids))
        x_valid, y_valid = x_valid[num_indexes], y_valid[num_indexes]


    # --------------------------------------------------------------------------
    #   Model and stuffs
    # --------------------------------------------------------------------------
    # : set the pretrained flags
    set_pretrain = _dataset in ['pubfig']

    # load the model that we compromise
    model = load_network(_dataset, _network, use_pretrain=set_pretrain)
    load_network_parameters(model, _netfile)
    print (' : [network] load the network from - {}'.format(_netfile))

    # set the store locations (the checkpoint file name is the last folder)
    print (' : [load] set the save locations:')
    save_pref = _netfile.split('/')[-1]
    save_adir = os.path.join('analysis', 'mitm', _dataset, _network, save_pref)
    os.makedirs(save_adir, exist_ok=True)
    print ('    - Save activations  to {}'.format(save_adir))
    save_tdir = os.path.join('datasets', 'mitm', _dataset, _network, save_pref)
    os.makedirs(save_tdir, exist_ok=True)
    print ('    - Save opt-triggers to {}'.format(save_tdir))


    # --------------------------------------------------------------------------
    #   Initialize a patch
    # --------------------------------------------------------------------------
    x_patch, x_masks = _initialize_patch(
        _dataset, _bdr_shape, _bdr_size, _bdr_intense, bbound=_epsilon)
    print (' : [init.] initialize a patch')
    print ('   - Patch: {} ({:.3f} ~ {:.3f})'.format(list(x_patch.shape), x_patch.min(), x_patch.max()))
    print ('   - Mask : {} ({:.3f} ~ {:.3f})'.format(list(x_masks.shape), x_masks.min(), x_masks.max()))

    # [DEBUG]
    if _verbose:
        vutils.save_image(
            torch.from_numpy(x_patch),
            os.path.join(save_adir, 'x_patch.{}.init.png'.format(_bdr_shape)))
        vutils.save_image(
            torch.from_numpy(x_masks),
            os.path.join(save_adir, 'x_masks.{}.init.png'.format(_bdr_shape)))


    # --------------------------------------------------------------------------
    #   Prep.
    # --------------------------------------------------------------------------
    # [Note] predictor is not used further down in this script
    predictor = objax.Jit(lambda x: model(x, training=False), model.vars())
    lprofiler = objax.Jit(lambda x: model(x, training=False, latent=True), model.vars())        # latent representation extractors
    print (' : [init.] compose Jit with a model (to have latent repr.)')

    # define a loss and a update function
    def _loss(x, p, m):
        # : fit the shapes (patch and mask) to the batch that was passed in;
        #   sizing by x (not by _num_batchs) keeps a short final batch working
        num_samples = x.shape[0]
        p = jnp.repeat(jnp.expand_dims(p, axis=0), num_samples, axis=0)
        m = jnp.repeat(jnp.expand_dims(m, axis=0), num_samples, axis=0)
        _, emb_adv = model(x * (1-m) + p * m, training=False, latent=True)
        _, emb_tar = model(x, training=False, latent=True)

        # : negated mean absolute error between the two latents, so that
        #   lowering the loss widens the gap between them
        return -1. * objax.functional.loss.mean_absolute_error(emb_adv, emb_tar).mean()

    # define the update function (gradients w.r.t. the patch, argument 1)
    update = objax.GradValues(_loss, objax.VarCollection(), input_argnums=(1,))


    # --------------------------------------------------------------------------
    #   (Init:Prev) compute the activation difference
    # --------------------------------------------------------------------------
    # compose the backdoor samples
    x_bdoor = _stamp_trigger(x_valid, x_patch, x_masks)

    # compute difference
    prev_adifference = _compute_activation_difference(x_valid, x_bdoor, lprofiler)
    print (' : [init.] compute the activation difference (over chosen neurons)')
    for nidx, adiff in prev_adifference.items():
        if adiff < _threshold: continue
        print ('   [n: {:4d}] {:.3f}'.format(nidx, adiff))


    # --------------------------------------------------------------------------
    #   Run (construct the universal adversarial triggers)
    # --------------------------------------------------------------------------
    # to track the loss
    tot_loss = {}

    # loop over an epoch
    for epoch in range(_num_epochs):
        # : track
        tot_loss[epoch] = []

        # : to sample randomly
        vsel = np.arange(len(x_valid))
        np.random.shuffle(vsel)

        # : loop over the shuffled validation samples, one batch at a time
        update_bar = tqdm(range(0, x_valid.shape[0], _num_batchs))
        for it in update_bar:
            x_obatch = x_valid[vsel[it:it + _num_batchs]]

            # :: loop over the # iterations
            for _ in range(_num_iters):

                # > compute gradients
                grads, losses = update(x_obatch, x_patch, x_masks)

                # > update the patch (signed step against the gradients)
                x_patch = x_patch - _step_size * np.sign(grads[0])

                # > bound the patch within [0, 1]
                x_patch = np.clip(x_patch, 0., 1.)

            # :: end for _ ...

            # :: store the loss (of the last iteration on this batch)
            tot_loss[epoch].append(losses[0])
            update_bar.set_description('   [epoch-{}:{:4f}]'.format(epoch, losses[0]))

        # : end for it...

        # : reduce the step-size by half in 5 epochs
        if (epoch + 1) % 5 == 0: _step_size *= 0.5

        # : show the progress
        print ('   [epoch-{}] avg. loss {:4f}'.format(epoch, sum(tot_loss[epoch])/len(tot_loss[epoch])))

    # end for epoch...

    # [DEBUG]
    if _verbose:
        vutils.save_image(
            torch.from_numpy(x_patch),
            os.path.join(save_adir, 'x_patch.{}.final.png'.format(_bdr_shape)))
        # : store some samples (clean, and the same ones with the trigger)
        _nsample = 8
        x_cleans = x_valid[:_nsample]
        x_advers = _stamp_trigger(x_cleans, x_patch, x_masks)
        # : outputs
        vutils.save_image(
            torch.from_numpy(x_cleans),
            os.path.join(save_adir, 'x_patch.{}.cleans.png'.format(_bdr_shape)))
        vutils.save_image(
            torch.from_numpy(x_advers),
            os.path.join(save_adir, 'x_patch.{}.advers.png'.format(_bdr_shape)))


    # --------------------------------------------------------------------------
    #   (After) Compute the activation difference
    # --------------------------------------------------------------------------
    # compose the backdoor samples
    x_bdoor = _stamp_trigger(x_valid, x_patch, x_masks)

    # compute difference
    next_adifference = _compute_activation_difference(x_valid, x_bdoor, lprofiler)
    print (' : [check] compare the activation differences (over chosen neurons)')
    for nidx in next_adifference:
        # : skip the neurons that stay under the threshold before and after
        if prev_adifference[nidx] < _threshold \
            and next_adifference[nidx] < _threshold: continue

        print ('   [n: {:4d}] {:.3f} -> {:.3f}'.format(
            nidx, prev_adifference[nidx], next_adifference[nidx]))


    # --------------------------------------------------------------------------
    #   Store the patch and the masks
    # --------------------------------------------------------------------------
    # : CHW in [0, 1] -> HWC in 8-bit
    pil_patch = Image.fromarray(np.uint8(x_patch.transpose(1, 2, 0) * 255))
    pil_masks = Image.fromarray(np.uint8(x_masks.transpose(1, 2, 0) * 255))

    # save...
    pil_pfile = os.path.join(save_tdir, 'x_patch.{}.png'.format(_bdr_shape))
    pil_mfile = os.path.join(save_tdir, 'x_masks.{}.png'.format(_bdr_shape))
    pil_patch.save(pil_pfile)
    pil_masks.save(pil_mfile)
    print (' : done.')
    # done.
