# Copyright 2021 The Handcrafted Backdoors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
""" Compute the model's Hessian and its eigenvalues """
import os, gc
import numpy as np
from PIL import Image

# torch
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

# misc.
import pickle
from tqdm import tqdm

# custom
from utils.datasets import load_dataset, blend_backdoor, NumpyDataset
from utils.models import load_torch_network, load_network_parameters_from_objax
from utils.learner import valid_torch

# PyHessian
from pyhessian import hessian


# ------------------------------------------------------------------------------
#   Dataset specific configurations
# ------------------------------------------------------------------------------
# dataset and network
# dataset / network / random seed
_seed, _dataset, _network = 215, 'mnist', 'FFNet'
_use_mitm = False           # True for cifar10 / False for the rest

# backdoor trigger: target label, pixel intensity, pattern, and patch size
_bd_label, _bd_intense = 0, 1.0     # intensity: 0.0 for svhn / 1.0 for the rest
_bd_shape, _bd_size    = 'checkerboard', 4
# number of handcrafted backdoor neurons; values used per dataset:
#   MNIST  -  4 (square) /  6 (checkerboard)
#   SVHN   - 38 (square) / 14 (checkerboard) / 30 (random)
_bd_neurons = 6

# network checkpoints (fine-tuned variants). The non-finetuned checkpoints are at
#   models/<dataset>/<network>/best_model_backdoor_<shape>_<size>_<intense>_5.0.npz
#   models/<dataset>/<network>/handcraft.bdoor/best_model_handcraft_<shape>_<size>_<intense>_<neurons>.npz
_net_poison = (
    'models/{}/{}/finetune/best_model_backdoor_{}_{}_{}_5.finetune.npz'
    .format(_dataset, _network, _bd_shape, _bd_size, _bd_intense))
_net_hcraft = (
    'models/{}/{}/finetune/best_model_handcraft_{}_{}_{}_{}.finetune.npz'
    .format(_dataset, _network, _bd_shape, _bd_size, _bd_intense, _bd_neurons))

# mitm models use an optimized trigger (patch + mask PNGs) and their own checkpoint
if _use_mitm:
    _bdr_hratio = 0.95      # 0.95 for square / 0.9 for the rest
    _bdr_fstore = 'datasets/mitm/{}/{}/best_model_base.npz'.format(_dataset, _network)
    # NOTE(review): _bdr_fstore ends in '.npz' but is joined as a directory
    # below — confirm the on-disk layout.
    _bdr_fpatch = os.path.join(_bdr_fstore, 'x_patch.{}.png'.format(_bd_shape))
    _bdr_fmasks = os.path.join(_bdr_fstore, 'x_masks.{}.png'.format(_bd_shape))
    _net_hcraft = (
        'models/{}/{}/handcraft.bdoor/best_model_handcraft_{}_{}.mitm.npz'
        .format(_dataset, _network, _bd_shape, _bdr_hratio))

# Hessian configurations: batch size used by every DataLoader below
_num_batchs = 128


# ------------------------------------------------------------------------------
#   Compute the accuracy and success rate
# ------------------------------------------------------------------------------
# set the random seeds (for the reproducible experiments): numpy draws the
# synthetic random inputs below, while torch drives DataLoader shuffling
# (and presumably PyHessian's random probe vectors), so both must be seeded.
np.random.seed(_seed)
torch.manual_seed(_seed)

# set the cuda if available
use_cuda = torch.cuda.is_available()
print (' : set CUDA to [{}]'.format(use_cuda))


# data: svhn and cifar10 are requested with flatten=True, the rest with False
(x_train, y_train), (x_valid, y_valid) = load_dataset(
    _dataset, flatten=_dataset in ('svhn', 'cifar10'))
print (' : load dataset [{}]'.format(_dataset))


# configurations: extra DataLoader kwargs, only on GPU (pinned host memory
# speeds host-to-device copies). CPU runs use the DataLoader defaults.
# NOTE(review): a num_workers=4 CPU setting was once listed here but could never
# take effect; add it back explicitly if CPU data loading becomes a bottleneck.
kwargs = { 'num_workers': 0, 'pin_memory' : True } if use_cuda else {}


# compose loaders: shuffled training batches, fixed-order validation batches
train_loader = DataLoader(NumpyDataset(x_train, y_train),
                          batch_size=_num_batchs, shuffle=True, **kwargs)
valid_loader = DataLoader(NumpyDataset(x_valid, y_valid),
                          batch_size=_num_batchs, shuffle=False, **kwargs)
print (' : compose data loaders')


# craft the backdoor datasets (standard): stamp the trigger onto a copy of the
# training inputs and relabel every sample with the backdoor target label.
x_bdoor = blend_backdoor(
    np.copy(x_train), dataset=_dataset, network=_network,
    shape=_bd_shape, size=_bd_size, intensity=_bd_intense)
# One label per blended *training* sample. Sizing this from y_valid gave a label
# array whose length differs from len(x_bdoor) whenever the splits differ in size.
# (assumes blend_backdoor keeps the sample count — TODO confirm)
y_bdoor = np.full(y_train.shape, _bd_label)
print (' : [load] create the backdoor dataset (standard)')

# craft the backdoor dataset (mitm-models): blend the optimized patch into the
# training inputs through its mask, i.e. x * (1 - mask) + patch * mask.
if _use_mitm:
    # `with` closes the PNG file handles once the pixels have been copied out.
    # assumes the PNGs decode to (H, W, C) matching x_train's channel count
    # (an RGBA file would yield 4 channels) — TODO confirm
    with Image.open(_bdr_fpatch) as fpatch:
        x_patch = np.asarray(fpatch).transpose(2, 0, 1) / 255.
    with Image.open(_bdr_fmasks) as fmasks:
        x_masks = np.asarray(fmasks).transpose(2, 0, 1) / 255.

    # blend the backdoor patch ... the leading size-1 axis broadcasts over
    # every training sample, giving the same result as np.repeat without
    # materializing two extra train-sized arrays
    xp = np.expand_dims(x_patch, axis=0)
    xm = np.expand_dims(x_masks, axis=0)
    xmbdoor = x_train * (1-xm) + xp * xm
    # one label per blended training sample (must match len(xmbdoor))
    ymbdoor = np.full(y_train.shape, _bd_label)
    print (' : [load] create the backdoor dataset (mitm-models)')

gc.collect()    # to control the memory space


# compose the backdoor loader(s). shuffle=True matters here: the Hessian loop
# takes only the first batch of a fresh iterator, so shuffling makes each
# draw a different random batch.
bdoor_loader = DataLoader(NumpyDataset(x_bdoor, y_bdoor),
                          batch_size=_num_batchs, shuffle=True, **kwargs)
print (' : compose bdoor loader [{}]'.format(len(bdoor_loader.dataset)))

# mitm models are attacked with their own optimized-patch data
if _use_mitm:
    bdmit_loader = DataLoader(NumpyDataset(xmbdoor, ymbdoor),
                              batch_size=_num_batchs, shuffle=True, **kwargs)
    print (' : compose bdoor loader [{}]'.format(len(bdmit_loader.dataset)))


# compose the random data: uniform noise in [0, 1) shaped like the backdoor
# inputs (shuffle=True so every first-batch draw differs).
# NOTE(review): np.random.randint's upper bound is exclusive, so randint(0, 1)
# labels every random sample 0 — confirm that is intended, not random classes.
# NOTE(review): np.random.random yields float64; presumably NumpyDataset casts
# to the model's dtype — verify.
randn_loader = DataLoader(
    NumpyDataset(
        np.random.random(x_bdoor.shape),
        np.random.randint(0, 1, size=y_bdoor.shape),
    ),
    batch_size=_num_batchs, shuffle=True, **kwargs)
print (' : compose randn loader [{}]'.format(len(randn_loader.dataset)))


# load the network (torch): two instances of the same architecture, one per
# checkpoint under comparison (poisoned vs. handcrafted backdoor)
model_poison, model_hcraft = [load_torch_network(_dataset, _network) for _ in range(2)]
print (' : use the networks - {}'.format(type(model_hcraft).__name__))

# load the parameters (from objax to torch models)
model_poison = load_network_parameters_from_objax(model_poison, _dataset, _network, _net_poison)
model_hcraft = load_network_parameters_from_objax(model_hcraft, _dataset, _network, _net_hcraft)
print (' : load the netparams')
print ('  > poison: {}'.format(_net_poison))
print ('  > hcraft: {}'.format(_net_hcraft))

# move both models onto the GPU when CUDA is available
if use_cuda:
    model_poison, model_hcraft = model_poison.cuda(), model_hcraft.cuda()
    print (' : set the models to cuda')


# run evaluations for the both: clean accuracy on the validation split and
# backdoor accuracy on triggered inputs (the mitm-patched set is used for the
# handcrafted model when _use_mitm is on)
def _evaluate(model, loader):
    """Return the first value valid_torch reports for `model` on `loader`."""
    score, _aux = valid_torch('[N/A]', model, loader, F.cross_entropy, use_cuda=use_cuda)
    return score


clean_poison = _evaluate(model_poison, valid_loader)
bdoor_poison = _evaluate(model_poison, bdoor_loader)
clean_hcraft = _evaluate(model_hcraft, valid_loader)
bdoor_hcraft = _evaluate(model_hcraft, bdmit_loader if _use_mitm else bdoor_loader)
print (' : [valid:poison] clean acc. %.2f / bdoor acc. %.2f' % (clean_poison, bdoor_poison))
print (' : [valid:hcraft] clean acc. %.2f / bdoor acc. %.2f' % (clean_hcraft, bdoor_hcraft))


# compute eigenvalues few times: each round draws a fresh (shuffled) batch from
# every loader and records PyHessian's top-1 eigenvalue of the cross-entropy
# loss Hessian for both models on that same batch.
ntimes    = 100
cp_eigens = []      # poison model      - clean batches
bp_eigens = []      # poison model      - backdoor batches
rp_eigens = []      # poison model      - random-noise batches

ch_eigens = []      # handcrafted model - clean batches
bh_eigens = []      # handcrafted model - backdoor batches (mitm-patched if _use_mitm)
rh_eigens = []      # handcrafted model - random-noise batches


def _first_batch(loader):
    """Return (data, labels) from a fresh iterator over `loader`.

    For the shuffled loaders used here, each call yields a different batch.
    Raises StopIteration if the loader is empty.
    """
    data, labels = next(iter(loader))
    return data, labels


def _top_eigenvalue(model, data, labels):
    """Return the top_n=1 eigenvalue list PyHessian computes for `model` on one batch."""
    model_hessian = hessian(model, F.cross_entropy, data=(data, labels), cuda=use_cuda)
    top_eigen, _ = model_hessian.eigenvalues(top_n=1)
    return top_eigen


# loop over it (the call order below is kept fixed so RNG consumption is unchanged)
for _ in tqdm(range(ntimes), desc=" : [compute-hessian]"):

    # --------------------------------
    # on the clean samples (same batch for both models)
    data, labels = _first_batch(train_loader)
    cp_eigens.append(_top_eigenvalue(model_poison, data, labels))
    ch_eigens.append(_top_eigenvalue(model_hcraft, data, labels))

    # --------------------------------
    # on the backdoor samples; the mitm handcrafted model is probed with its
    # own optimized-patch batch instead of the standard trigger
    data, labels = _first_batch(bdoor_loader)
    bp_eigens.append(_top_eigenvalue(model_poison, data, labels))
    if _use_mitm:
        data, labels = _first_batch(bdmit_loader)
    bh_eigens.append(_top_eigenvalue(model_hcraft, data, labels))

    # --------------------------------
    # on the random samples (same batch for both models)
    data, labels = _first_batch(randn_loader)
    rp_eigens.append(_top_eigenvalue(model_poison, data, labels))
    rh_eigens.append(_top_eigenvalue(model_hcraft, data, labels))

# done... summarize the top eigenvalue over the ntimes rounds (mean / std.)

print (' : [Poison] Top eigenvalue --------')
print ('  - [Clean] mean: {:.2f}, std.: {:.2f}'.format(
    np.array(cp_eigens).mean(), np.array(cp_eigens).std() ))
print ('  - [Bdoor] mean: {:.2f}, std.: {:.2f}'.format(
    np.array(bp_eigens).mean(), np.array(bp_eigens).std() ))
print ('  - [Randn] mean: {:.2f}, std.: {:.2f}'.format(
    np.array(rp_eigens).mean(), np.array(rp_eigens).std() ))

print (' : [Hcraft] Top eigenvalue --------')
print ('  - [Clean] mean: {:.2f}, std.: {:.2f}'.format(
    np.array(ch_eigens).mean(), np.array(ch_eigens).std() ))
print ('  - [Bdoor] mean: {:.2f}, std.: {:.2f}'.format(
    np.array(bh_eigens).mean(), np.array(bh_eigens).std() ))
print ('  - [Randn] mean: {:.2f}, std.: {:.2f}'.format(
    np.array(rh_eigens).mean(), np.array(rh_eigens).std() ))

# per-round ratio of the poisoned to the handcrafted model's top eigenvalue on
# backdoor batches, averaged over the rounds
print (' : [Ratio] Poisoning / Handcrafting --------')
print ('  - [Bdoor] ratio: {:.2f}'.format(
    np.divide( np.array(bp_eigens), np.array(bh_eigens)).mean() ))
