#########################################################################
##   This file is part of the α,β-CROWN (alpha-beta-CROWN) verifier    ##
##                                                                     ##
##   Copyright (C) 2021-2024 The α,β-CROWN Team                        ##
##   Primary contacts: Huan Zhang <huan@huan-zhang.com>                ##
##                     Zhouxing Shi <zshi@cs.ucla.edu>                 ##
##                     Kaidi Xu <kx46@drexel.edu>                      ##
##                                                                     ##
##    See CONTRIBUTORS for all author contacts and affiliations.       ##
##                                                                     ##
##     This program is licensed under the BSD 3-Clause License,        ##
##        contained in the LICENCE file in this directory.             ##
##                                                                     ##
#########################################################################

import math
import time
import torch
import torch.nn as nn
from torch.optim import Optimizer
import numpy as np
from tqdm import tqdm
import arguments
import os
import subprocess
from load_model import Customized
import sys


# Import-time side effect: pass False to two private TorchScript hooks.
# NOTE(review): presumably this switches off the JIT profiling executor/mode;
# these are private torch._C APIs, so confirm they exist in the pinned torch version.
torch._C._jit_set_profiling_executor(False)
torch._C._jit_set_profiling_mode(False)

class Normalization(nn.Module):
    """Module wrapper that standardizes its input before calling ``model``.

    ``mean`` and ``std`` are registered as parameters with
    ``requires_grad=False``; ``model`` is stored as a sub-module attribute.
    """

    def __init__(self, mean, std, model):
        super().__init__()
        frozen_mean = nn.Parameter(mean, requires_grad=False)
        frozen_std = nn.Parameter(std, requires_grad=False)
        self.mean = frozen_mean
        self.std = frozen_std
        self.model = model

    def forward(self, x):
        """Return ``model((x - mean) / std)``."""
        standardized = (x - self.mean) / self.std
        return self.model(standardized)

def clamp(X, lower_limit=None, upper_limit=None):
    """Element-wise clamp of ``X`` to ``[lower_limit, upper_limit]``.

    Args:
        X: tensor to clamp.
        lower_limit: optional tensor (broadcastable to ``X``) of lower bounds.
        upper_limit: optional tensor (broadcastable to ``X``) of upper bounds.

    Returns:
        ``X`` itself when no limit is given; otherwise a new tensor with every
        provided bound applied. When both bounds are given the upper bound is
        applied first and the lower bound second (the lower bound wins if the
        two conflict).
    """
    # The previous implementation returned as soon as the lower bound had been
    # applied, so the upper bound was silently ignored when both were passed.
    if upper_limit is not None:
        X = torch.min(X, upper_limit)
    if lower_limit is not None:
        X = torch.max(X, lower_limit)
    return X



def save_cex(adv_example, adv_output, x, vnnlib, res_path, data_max, data_min):
    """Save a counterexample to ``res_path`` and optionally re-check it externally.

    The specification is rebuilt from ``vnnlib`` (via ``process_vnn_lib_attack`` and
    ``build_conditions``), only the first adversarial example / output is kept, and
    the saver named by ``arguments.Config["attack"]["adv_saver"]`` is invoked on it.
    When ``arguments.Config["general"]["eval_adv_example"]`` is set, the script
    ``check_counterexample.py`` (located one directory above this file) is run in a
    subprocess on the saved file.

    Args:
        adv_example: adversarial inputs; sliced as ``adv_example[:, 0:1]``.
        adv_output: model outputs for the adversarial inputs; sliced as ``adv_output[0:1]``.
        x: original input, used for its shape and device.
        vnnlib: parsed specification entries, forwarded to ``process_vnn_lib_attack``.
        res_path: path of the counterexample file to write.
        data_max: upper input bounds, reshaped to match the OR/AND layout of the spec.
        data_min: lower input bounds, reshaped the same way.
    """
    list_target_label_arrays, _, _ = process_vnn_lib_attack(vnnlib, x)
    C_mat, rhs_mat, cond_mat, same_number_const = build_conditions(x, list_target_label_arrays)

    # [num_example, num_restarts, num_spec, output_dim] ->
    # [num_example, num_restarts, num_or_spec, num_and_spec, output_dim]
    C_mat = C_mat.view(C_mat.shape[0], 1, len(cond_mat[0]), -1, C_mat.shape[-1])
    rhs_mat = rhs_mat.view(rhs_mat.shape[0], len(cond_mat[0]), -1)
    adv_example = adv_example[:,0:1]
    adv_output = adv_output[0:1]
    # adv_example and adv_output are duplicate num_or_spec times due to the duplicated data_max and data_min
    #[num_example, num_or_spec, num_and_spec]

    # Margin of every AND row: C @ y - rhs; a negative value means the row is violated.
    attack_margin = torch.matmul(C_mat, adv_output.unsqueeze(1).unsqueeze(-1)).squeeze(-1) - rhs_mat
    data_max = data_max.view(data_max.shape[0], 1, len(cond_mat[0]), -1, *x.shape[1:])
    data_min = data_min.view(data_min.shape[0], 1, len(cond_mat[0]), -1, *x.shape[1:])

    # An OR clause counts as violated only when all of its AND rows are negative.
    violated = (attack_margin < 0).all(-1)
    # [num_example, 1, num_or_spec]
    max_valid = (adv_example <= data_max).view(*data_max.shape[:4], -1)
    min_valid = (adv_example >= data_min).view(*data_max.shape[:4], -1)
    # [num_example, 1, num_or_spec, num_and_spec, -1]

    max_valid = max_valid.all(-1).all(-1)
    min_valid = min_valid.all(-1).all(-1)
    # [num_example, 1, num_or_spec]

    # NOTE(review): violate_index is computed but not read anywhere below in this function.
    violate_index = (violated & max_valid & min_valid).nonzero()
    # [num_examples, num_restarts, num_or_spec]

    # The saver is looked up by evaluating a config string as a Python expression,
    # so the configuration must come from a trusted source.
    eval(arguments.Config["attack"]["adv_saver"])(adv_example, adv_output, res_path)

    if arguments.Config["general"]["eval_adv_example"]:
        onnx_path = arguments.Config["model"]["onnx_path"]
        vnnlib_path = arguments.Config["specification"]["vnnlib_path"]

        # Both paths are resolved against the current working directory.
        onnx_path = os.path.join(os.getcwd(), onnx_path)
        vnnlib_path = os.path.join(os.getcwd(), vnnlib_path)
        # Directory of this file (split on '/'), then one level up.
        script_path = os.path.join('/'.join(__file__.split('/')[:-1]), '../', 'check_counterexample.py')

        # print(onnx_path, vnnlib_path, script_path)
        try:
            subprocess.run([sys.executable, script_path, onnx_path, vnnlib_path, res_path], check=True)
        except subprocess.CalledProcessError:
            # Best effort: a failing checker is reported but does not abort the caller.
            print('Unexpected error in checking adv example')


def default_adv_saver(adv_example, adv_output, res_path):
    """Write a counterexample file and optionally validate it with an external script.

    The file consists of one parenthesised list: every flattened input value as
    ``(X_i  value)`` followed by every entry of the first output row as
    ``(Y_i value)``.

    Args:
        adv_example: adversarial input tensor; flattened for writing.
        adv_output: model output tensor; only row 0 is written.
        res_path: destination path of the counterexample file.
    """
    flat_input = adv_example.view(-1).detach().cpu()
    output_np = adv_output.detach().cpu().numpy()

    with open(res_path, 'w+') as cex_file:
        # Number of scalar entries in the first example.
        num_inputs = np.prod(adv_example[0].shape)

        cex_file.write("(")
        for pos in range(num_inputs):
            cex_file.write("(X_{}  {})\n".format(pos, flat_input[pos].item()))

        # Output assignments are newline-separated, with no trailing newline.
        y_entries = [
            "(Y_{} {})".format(pos, output_np[0, pos])
            for pos in range(output_np.shape[1])
        ]
        cex_file.write("\n".join(y_entries))
        cex_file.write(")")
        cex_file.flush()

    if not arguments.Config["general"]["eval_adv_example"]:
        return

    cwd = os.getcwd()
    onnx_path = os.path.join(cwd, arguments.Config["model"]["onnx_path"])
    vnnlib_path = os.path.join(cwd, arguments.Config["specification"]["vnnlib_path"])
    # Directory part of this file's path (text before the last '/'), then one level up.
    this_dir = __file__.rpartition('/')[0]
    script_path = os.path.join(this_dir, '../', 'check_counterexample.py')

    checker_cmd = [sys.executable, script_path, onnx_path, vnnlib_path, res_path]
    try:
        subprocess.run(checker_cmd, check=True)
    except subprocess.CalledProcessError:
        print('Unexpected error in checking adv example')


def process_vnn_lib_attack(vnnlib, x):
    """Flatten parsed VNNLIB entries into attack-ready specs and input bounds.

    Each entry of ``vnnlib`` is indexed as ``entry[0]`` (per-input
    ``[lower, upper]`` pairs) and ``entry[1]`` (a sequence of output specs).
    The bounds of an entry are repeated once per spec of that entry.

    Args:
        vnnlib: iterable of parsed entries as described above.
        x: reference input; supplies the per-example shape and the device.

    Returns:
        ``(list_target_label_arrays, data_min_repeat, data_max_repeat)`` where the
        first item is a one-element list holding all specs in order, and the two
        tensors have shape ``[1, total_specs, *x.shape[1:]]``.
    """
    sample_shape = x.shape[1:]

    def _tile_bound(bounds, column, copies):
        # One bound column reshaped to the input shape and repeated per spec.
        column_values = bounds[:, column].view(-1, *sample_shape).to(x.device)
        return column_values.expand(copies, *sample_shape).unsqueeze(0)

    all_specs = []
    lower_blocks = []
    upper_blocks = []

    for entry in vnnlib:
        bounds = torch.Tensor(entry[0])
        copies = len(entry[1])

        upper_blocks.append(_tile_bound(bounds, 1, copies))
        lower_blocks.append(_tile_bound(bounds, 0, copies))
        all_specs.extend(list(entry[1]))

    data_min_repeat = torch.cat(lower_blocks, dim=1)
    data_max_repeat = torch.cat(upper_blocks, dim=1)

    return [all_specs], data_min_repeat, data_max_repeat


def attack(model_ori, x, vnnlib, verified_status, verified_success,
           crown_filtered_constraints=None, initialization='uniform'):
    """Run the configured attack on a single example and report the outcome.

    Args:
        model_ori: model under attack.
        x: input tensor; must have batch size 1.
        vnnlib: parsed specification, forwarded to ``process_vnn_lib_attack``.
        verified_status: incoming status, returned unchanged unless the attack succeeds.
        verified_success: incoming success flag, returned unchanged unless the attack succeeds.
        crown_filtered_constraints: optional per-spec flags; specs whose flag is
            truthy are skipped.
        initialization: PGD initialization; forced to ``"osi"`` when the attack
            mode contains ``"diversed"``.

    Returns:
        ``(verified_status, verified_success, attack_images, attack_margins,
        all_adv_candidates)``; on success the status becomes ``"unsafe-pgd"`` and
        the flag ``True``.

    Raises:
        NotImplementedError: if the configured attack mode contains ``'auto_attack'``.
    """
    attack_mode = arguments.Config["attack"]["attack_mode"]
    if 'auto_attack' in attack_mode:
        raise NotImplementedError('Auto-attack interfact has not been implemented yet.')

    if "diversed" in attack_mode:
        initialization = "osi"
    GAMA_loss = "GAMA" in attack_mode

    # This code path only handles a single example.
    assert x.shape[0] == 1

    # list_target_label_arrays holds, per example, one (prop_mat, prop_rhs) tuple per OR
    # clause; the rows of each tuple are the AND conditions of that clause.
    # data_min_repeat / data_max_repeat: [batch_size, spec_num, *input_shape].
    list_target_label_arrays, data_min_repeat, data_max_repeat = process_vnn_lib_attack(vnnlib, x)

    if crown_filtered_constraints is not None:
        # Keep only the specs that were not flagged in crown_filtered_constraints.
        assert len(list_target_label_arrays) == 1  # only support batch_size=1 cases
        remaining_specs = [
            spec for spec_idx, spec in enumerate(list_target_label_arrays[0])
            if not crown_filtered_constraints[spec_idx]
        ]
        list_target_label_arrays = [remaining_specs]
        print(f"Remain {len(list_target_label_arrays[0])} labels need to be attacked.")

    # The attack routine is looked up by evaluating a (trusted) config string.
    attack_function = eval(arguments.Config["attack"]["attack_func"])
    attack_ret, attack_images, attack_margins, all_adv_candidates = attack_function(
        model_ori,
        x,
        data_min_repeat[:, :len(list_target_label_arrays[0]), ...],
        data_max_repeat[:, :len(list_target_label_arrays[0]), ...],
        list_target_label_arrays,
        initialization=initialization,
        GAMA_loss=GAMA_loss,
    )

    if attack_ret:
        # A counterexample was found.
        general_cfg = arguments.Config["general"]
        if general_cfg["save_adv_example"]:
            try:
                attack_output = model_ori(attack_images.view(-1, *x.shape[1:]))
                save_cex(attack_images, attack_output, x, vnnlib,
                         arguments.Config["attack"]["cex_path"],
                         data_max_repeat, data_min_repeat)
            except Exception as e:
                # Saving is best effort; a failure must not hide the attack result.
                print(str(e))
                print('save adv example failed')
        if general_cfg["show_adv_example"]:
            print('Adv example:')
            print(attack_images[0, 0])
        verified_status = "unsafe-pgd"
        verified_success = True

    # attack_images has shape (batch, spec, c, h, w).
    return verified_status, verified_success, attack_images, attack_margins, all_adv_candidates


def default_pgd_loss(origin_out, output, C_mat, rhs_mat, cond_mat, same_number_const,
                     gama_lambda=0, threshold=-1e-5, mode='hinge', model=None):
    '''
    Hinge-style PGD loss on the specification margins.

    output: [num_example, num_restarts, num_or_spec, num_output]
    C_mat: [num_example, num_restarts, num_spec, num_output]
    rhs_mat: [num_example, num_spec]
    cond_mat: [[]] * num_examples, number of AND rows in every OR group
    origin_out: reference output for the GAMA term, or None to disable it
    gama_lambda: weight of the squared distance between output and origin_out
    threshold: lower clamp applied to the margin before negation
    same_number_const (bool): True when every OR group has the same number of AND rows
    mode: when "sum", non-negative loss entries are overwritten with 1.0
    model: unused here

    Returns (loss, loss_gamma): the per-spec loss used for best-candidate selection,
    and the scalar used for the gradient step.
    '''
    if same_number_const:
        # Regroup the flat spec axis into [num_or_spec, num_and_spec].
        C_mat = C_mat.view(C_mat.shape[0], 1, output.shape[2], -1, C_mat.shape[-1])
        rhs_mat = rhs_mat.view(rhs_mat.shape[0], 1, output.shape[2], -1)
        margin = C_mat.matmul(output.unsqueeze(-1)).squeeze(-1) - rhs_mat
        # [num_example, num_restarts, num_or_spec, num_and_spec]
    else:
        # Repeat every OR-group output once per AND row of that group.
        output = output.repeat_interleave(torch.tensor(cond_mat[0]).to(output.device), dim=2)
        if origin_out is not None:
            origin_out = origin_out.repeat_interleave(torch.tensor(cond_mat[0]).to(output.device), dim=2)
        # [num_example, num_restarts, num_spec, num_output]

        C_mat = C_mat.view(C_mat.shape[0], 1, -1, C_mat.shape[-1])
        rhs_mat = rhs_mat.view(rhs_mat.shape[0], 1, -1)
        margin = (C_mat * output).sum(-1) - rhs_mat
        # [num_example, num_restarts, num_spec]

    shifted = margin + arguments.Config["attack"]["attack_tolerance"]
    loss = -torch.clamp(shifted, min=threshold)

    loss_gamma = loss.sum()
    if origin_out is not None:
        loss_gamma = loss_gamma + (gama_lambda * (output - origin_out)**2).sum(dim=3).sum()

    if mode == "sum":
        loss[loss >= 0] = 1.0
    # loss drives best-loss selection; loss_gamma drives the gradient step.
    return loss, loss_gamma


def test_conditions(input, output, C_mat, rhs_mat, cond_mat, same_number_const, data_max, data_min, return_success_idx=False):
    '''
    Check whether any candidate is a valid adversarial example for the specification.

    A candidate succeeds when, for at least one OR group, every AND row satisfies
    ``C @ output - rhs + attack_tolerance < 0`` (uniform groups) or has zero clamped
    violation (non-uniform groups), and the candidate input lies inside
    ``[data_min, data_max]``.

    input: [num_example, num_restarts, num_or_spec, *input_shape]
    output: [num_example, num_restarts, num_or_spec, num_output]
    C_mat: [num_example, num_restarts, num_spec, num_output] or [num_example, num_spec, num_output]
    rhs_mat: [num_example, num_spec]
    cond_mat: [[]] * num_examples, number of AND rows in every OR group
    same_number_const (bool): True when every OR group has the same number of AND rows
    data_max & data_min: input bounds, compared element-wise against ``input``
    return_success_idx (bool): when True, also return an index (see below)

    Returns ``res`` (boolean tensor, one entry per example). With
    ``return_success_idx=True`` returns ``(res, idx)`` where ``idx`` is the integer
    index selected by the nested ``torch.min`` below when all examples succeed, and
    ``float('nan')`` (a dummy) otherwise.
    '''
    num_or_spec = len(cond_mat[0])
    tolerance = arguments.Config["attack"]["attack_tolerance"]

    if same_number_const:
        C_mat = C_mat.view(C_mat.shape[0], 1, num_or_spec, -1, C_mat.shape[-1])
        # [batch_size, restarts, num_or_spec, num_and_spec, output_dim]
        rhs_mat = rhs_mat.view(rhs_mat.shape[0], 1, num_or_spec, -1)

        # apply a small tolerance to rhs so that we are more confident about the adv example
        cond = torch.matmul(C_mat, output.unsqueeze(-1)).squeeze(-1) - rhs_mat + tolerance

        valid = ((input <= data_max) & (input >= data_min))
        valid = valid.reshape(*valid.shape[:3], -1)
        valid = valid.all(-1).view(valid.shape[0], valid.shape[1], num_or_spec, -1)
        # [num_example, restarts, num_or_spec, num_and_spec]

        # The worst (largest) AND margin of a group must be negative for the group to hold.
        res = ((cond.amax(dim=-1, keepdim=True) < 0.0) & valid).any(dim=-1).any(dim=-1).any(dim=-1)

        if res.all() and return_success_idx:
            # Out-of-bounds candidates are zeroed so torch.min does not select them.
            vio_value = cond.amax(dim=-1, keepdim=True) * valid
            # index of the adv example with the largest violation
            idx = int(torch.min(torch.min(vio_value, dim=1).values, dim=1).indices)
            return res, idx

    else:
        output = output.repeat_interleave(torch.tensor(cond_mat[0]).to(output.device), dim=2)
        # [num_example, num_restarts, num_spec, num_output]

        C_mat = C_mat.view(C_mat.shape[0], 1, -1, C_mat.shape[-1])
        # [num_example, 1, num_spec, num_output]
        rhs_mat = rhs_mat.view(rhs_mat.shape[0], 1, -1)
        # [num_example, 1, num_spec]

        # apply a small tolerance to rhs so that we are more confident about the adv example
        cond = torch.clamp((C_mat * output).sum(-1) - rhs_mat + tolerance, min=0.0)
        # [num_example, num_restarts, num_spec]

        # group_C[g, s] == 1 iff flat spec row s belongs to OR group g.
        group_C = torch.zeros(num_or_spec, C_mat.shape[2], device=cond.device) # [num_or_spec, num_total_spec]
        x_index = []
        y_index = []
        index = 0

        for i, num_cond in enumerate(cond_mat[0]):
            x_index.extend([i] * num_cond)
            # Flat integer columns. The previous code appended one-element lists, which made
            # the index pair broadcast to an outer product and filled whole rows with ones.
            y_index.extend(range(index, index + num_cond))
            index += num_cond

        group_C[x_index, y_index] = 1.0

        # Sum the clamped violations inside every OR group: [num_example, num_restarts, num_or_spec]
        cond = group_C.matmul(cond.unsqueeze(-1)).squeeze(-1)

        valid = ((input <= data_max) & (input >= data_min))
        valid = valid.reshape(*valid.shape[:3], -1)
        valid = valid.all(-1).view(valid.shape[0], valid.shape[1], num_or_spec, -1)
        # [num_example, restarts, num_or_spec, num_and_spec]

        valid = valid.all(-1)

        # [num_example, num_restarts, num_or_spec]
        res = ((cond == 0.0) & valid).any(dim=-1).any(dim=-1)

        if res.all() and return_success_idx:
            # NOTE(review): the raw margin has a num_spec axis while cond/valid have a
            # num_or_spec axis; this product only broadcasts when those sizes are
            # compatible -- confirm the intended aggregation with the author.
            vio_value = ((C_mat * output).sum(-1) - rhs_mat) * (cond == 0.0) * valid
            # index of the adv example with the largest violation
            idx = int(torch.min(torch.min(vio_value, dim=1).values, dim=1).indices)
            return res, idx

    if return_success_idx:
        # just return a dummy index, won't be used
        return res, float('nan')

    return res


def default_early_stop_condition(inputs, output, C_mat, rhs_mat, cond_mat, same_number_const,
    data_max, data_min, model, indices, num_or_spec, return_success_idx=False):
    """Default early-stop test: delegate to ``test_conditions``.

    ``model``, ``indices`` and ``num_or_spec`` are accepted for interface
    compatibility and are not used here.
    """
    stop_decision = test_conditions(
        inputs, output, C_mat, rhs_mat, cond_mat, same_number_const,
        data_max, data_min, return_success_idx=return_success_idx,
    )
    return stop_decision


def build_conditions(x, list_target_label_arrays):
    '''
    Assemble the specification tensors from per-example (prop_mat, prop_rhs) pairs.

    x: reference input; supplies the batch size and the device.
    list_target_label_arrays: for every example, a list of (prop_mat, prop_rhs)
        pairs, one per OR group; ``prop_rhs.shape[0]`` is the number of AND rows.

    Returns (C_mat, rhs_mat, cond_mat, same_number_const):
        C_mat: [num_example, num_spec, num_output]
        rhs_mat: [num_example, num_spec]
        cond_mat: per example, the number of AND rows in every OR group
        same_number_const: True when every OR group has the same number of AND rows
    '''
    device = x.device
    num_examples = x.shape[0]

    cond_mat = []
    C_per_example = []
    rhs_per_example = []

    same_number_const = True
    const_num = None
    for example_idx in range(num_examples):
        mats, rhs_vectors, group_sizes = [], [], []
        for prop_mat, prop_rhs in list_target_label_arrays[example_idx]:
            mats.append(torch.Tensor(prop_mat).to(device))
            rhs_vectors.append(torch.Tensor(prop_rhs).to(device))
            num_and = prop_rhs.shape[0]
            group_sizes.append(num_and)  # size of this `and` group
            if const_num is None or num_and == const_num:
                const_num = num_and
            else:
                same_number_const = False

        cond_mat.append(group_sizes)
        # Per example: [1, num_spec, num_output] and [1, num_spec].
        C_per_example.append(torch.cat(mats, dim=0).unsqueeze(0))
        rhs_per_example.append(torch.cat(rhs_vectors, dim=0).unsqueeze(0))

    try:
        # Stacking only works when every example has the same number of spec rows.
        C_mat = torch.cat(C_per_example, dim=0)
        rhs_mat = torch.cat(rhs_per_example, dim=0)
    except (RuntimeError, ValueError):
        print("Only support batches when the examples have the same number of constraints.")
        assert False

    return C_mat, rhs_mat, cond_mat, same_number_const


def default_adv_example_finalizer(model_ori, x, best_deltas, data_max, data_min, C_mat, rhs_mat, cond_mat):
    """Turn the best perturbations into final adversarial images and margins.

    Args:
        model_ori: model used to evaluate the adversarial images.
        x, best_deltas: tensors of shape (batch, c, h, w).
        data_max, data_min: input bounds of shape (batch, spec, c, h, w).
        C_mat, rhs_mat: specification matrix and right-hand side.
        cond_mat: per example, the number of AND rows in every OR group.

    Returns:
        ``(attack_image, attack_output, attack_margin)``.
    """
    # Project x + delta into the per-spec input box.
    perturbed = (x + best_deltas).unsqueeze(1)
    attack_image = torch.max(torch.min(perturbed, data_max), data_min)
    assert (attack_image >= data_min).all()
    assert (attack_image <= data_max).all()

    flat_images = attack_image.view(-1, *x.shape[1:])
    attack_output = model_ori(flat_images).view(*attack_image.shape[:2], -1)
    # [batch_size, num_or_spec, out_dim]

    if arguments.Config['general']['save_output']:
        arguments.Globals['out']['pred_adv'] = attack_output[0][0].cpu()

    # Only a small slice of the outputs is printed.
    print("Adv example prediction (first 2 examples and 2 restarts):\n", attack_output[:2,:2])

    # Repeat every OR-group output once per AND row of that group.
    and_counts = torch.tensor(cond_mat[0]).to(x.device)
    attack_output_repeat = attack_output.unsqueeze(1).repeat_interleave(and_counts, dim=2)
    # [num_example, num_restarts, num_spec, num_output]

    C_mat = C_mat.view(C_mat.shape[0], 1, -1, C_mat.shape[-1])  # [num_example, 1, num_spec, num_output]
    rhs_mat = rhs_mat.view(rhs_mat.shape[0], 1, -1)  # [num_example, 1, num_spec]

    attack_margin = (C_mat * attack_output_repeat).sum(-1) - rhs_mat
    # [num_example, num_restarts, num_spec]

    if arguments.Config['general']['save_output']:
        arguments.Globals['out']['attack_margin'] = attack_margin.cpu()

    print("PGD attack margin (first 2 examles and 10 specs):\n", attack_margin[:2, :, :10])
    num_violations = (attack_margin < 0).sum().item()
    print("number of violation: ", num_violations)

    return attack_image, attack_output, attack_margin

def OSI_init_C(model, X, alpha, output_dim, iter_steps=50, lower_limit=0.0, upper_limit=1.0):
    """General output-space-diversity initialization for PGD.

    Starting from ``X``, takes ``iter_steps`` signed-gradient steps of size
    ``alpha`` that increase the dot product between the model output and a
    random direction drawn uniformly from [-1, 1], projecting onto
    ``[lower_limit, upper_limit]`` after every step.

    Args:
        model: network evaluated on inputs flattened to ``[-1, *X.shape[3:]]``.
        X: tensor whose first three axes are folded into the batch axis.
        alpha: step size.
        output_dim: number of model outputs (width of the random direction).
        iter_steps: number of ascent steps.
        lower_limit, upper_limit: bounds passed to ``torch.max`` / ``torch.min``
            together with the full-shape candidate.

    Returns:
        The initialized inputs with the same shape as ``X``.
    """
    full_shape = X.shape
    candidate = X.clone().detach()
    candidate = candidate.view(-1, *candidate.shape[3:])
    X = X.view(-1, *X.shape[3:])
    # Both are now [batch_size * num_restarts * num_or_spec, *per-example shape].

    # One random direction in [-1, 1] per flattened input.
    direction = (torch.rand([X.shape[0], output_dim], device=X.device) - 0.5) * 2

    for _ in range(iter_steps):
        candidate = candidate.detach().requires_grad_()
        prediction = model(candidate)

        objective = torch.einsum('...,...->', direction, prediction)
        objective.backward()

        with torch.no_grad():
            stepped = candidate + alpha * torch.sign(candidate.grad)
            # Project in the full shape so the limits broadcast as the caller laid them out.
            stepped = stepped.view(full_shape)
            stepped = torch.max(torch.min(stepped, upper_limit), lower_limit)
            candidate = stepped.view(-1, *stepped.shape[3:])

    candidate = candidate.view(full_shape)
    X = X.view(full_shape)

    assert (candidate <= upper_limit).all()
    assert (candidate >= lower_limit).all()

    return candidate



def OSI_init(model, X, y, eps, alpha, num_classes, iter_steps=50, lower_limit=0.0, upper_limit=1.0, extra_dim=None):
    """Output-space-diversity initialization for a label-based PGD attack.

    Starting from a randomly perturbed copy of ``X``, takes signed-gradient steps of
    size ``alpha`` that increase the dot product between the model output and a random
    direction in [-1, 1], and returns early as soon as any prediction differs from ``y``.

    Args:
        model: network evaluated on inputs of shape ``[-1, *X.shape[1:]]``.
        X: input tensor.
        y: labels compared against ``output.argmax(-1)``.
        eps: perturbation radius; ``float('inf')`` skips the eps-based limit update.
        alpha: step size.
        num_classes: number of model outputs (width of the random direction).
        iter_steps: maximum number of steps.
        lower_limit, upper_limit: input bounds. ``.unsqueeze`` is called on them, so
            callers must pass tensors (the float defaults have no such method).
        extra_dim: optional pair of sizes; when given, ``X`` is expanded with two
            extra axes of these sizes after the batch axis.

    Returns:
        The initialized inputs with the expanded shape of ``X``.
    """
    input_shape = X.size()
    if extra_dim is not None:
        X = X.unsqueeze(1).unsqueeze(1).expand(-1, *extra_dim, *(-1,) * (X.ndim - 1))
    expand_shape = X.size()

    X_init = X.clone().detach()

    # Insert two axes after the batch axis so the limits broadcast against the expanded X.
    upper_limit = upper_limit.unsqueeze(1).unsqueeze(2)
    lower_limit = lower_limit.unsqueeze(1).unsqueeze(2)
    # Random start: uniform(-1, 1) scaled by the bound width, shifted by the lower bound.
    delta = (torch.empty_like(X).uniform_(-1, 1) * (upper_limit - lower_limit) + lower_limit)
    X_init = X_init + delta

    # Fold the extra axes into the batch axis for the model call.
    X = X.reshape(-1, *input_shape[1:])
    X_init = X_init.reshape(-1, *input_shape[1:])
    # Random vector from [-1, 1].
    w_d = (torch.rand([X.shape[0], num_classes], device=X.device) - 0.5) * 2

    if eps != float('inf'):
        # NOTE(review): X is flattened here while the limits carry the two extra axes,
        # so the result relies on broadcasting -- confirm the intended shapes.
        lower_limit = torch.clamp(X-eps, min=lower_limit)
        upper_limit = torch.clamp(X+eps, max=upper_limit)

    for i in range(iter_steps):
        X_init = X_init.detach().requires_grad_()
        output = model(X_init)

        if (output.argmax(-1) != y).any():
            # return if attack succeeds.
            return X_init.view(expand_shape)

        dot = (w_d * output).sum()
        grad = torch.autograd.grad(dot, X_init)[0]

        # Signed-gradient ascent step on the random projection.
        X_init = X_init + alpha * torch.sign(grad)

        # Project in the expanded shape: first onto the limits, then onto the eps ball around X.
        X_init = X_init.view(expand_shape)
        X_init = torch.max(torch.min(X_init, upper_limit), lower_limit)
        X = X.view(expand_shape)
        X_init = torch.max(torch.min(X_init, X+eps), X-eps)
        X_init = X_init.reshape(-1, *input_shape[1:])
        X = X.reshape(-1, *input_shape[1:])

    X_init = X_init.view(expand_shape)
    X = X.view(expand_shape)

    assert (X_init <= upper_limit).all()
    assert (X_init >= lower_limit).all()

    # NOTE(review): this tests `eps is not None`, whereas the check above compares
    # eps against float('inf'); confirm which sentinel means "no eps constraint".
    if eps is not None:
        assert (X_init <= X+eps).all()
        assert (X_init >= X-eps).all()

    return X_init

def pgd_attack_with_general_specs(model, X, data_min, data_max, C_mat, rhs_mat,
                                  cond_mat, same_number_const, alpha,
                                  use_adam=True, normalize=lambda x: x,
                                  initialization='uniform', GAMA_loss=False,
                                  num_restarts=None, pgd_steps=None,
                                  only_replicate_restarts=False,
                                  return_early_stopped=False):

    r''' PGD attack against a general (OR-of-AND) output specification.

    Args:
        model (torch.nn.Module): PyTorch module under attack.

        X (torch.tensor): Input image (x_0).

        data_min (torch.tensor): Lower bounds of data input. (e.g., 0 for mnist)

        data_max (torch.tensor): Upper bounds of data input. (e.g., 1 for mnist)

        C_mat (torch.tensor): [num_example, num_spec, num_output]

        rhs_mat (torch.tensor): [num_example, num_spec]

        cond_mat (list): [[] * num_example] mark the group of conditions

        same_number_const (bool): if same_number_const is True, it means that there are same number of and specifications in every or specification group.

        alpha (float): alpha for pgd attack

        use_adam (bool): step with ``AdamClipping`` and an exponential learning-rate
            schedule when True, otherwise take signed-gradient steps of size alpha.

        normalize (callable): applied to ``X + delta`` before the model is called.

        initialization (str): one of 'uniform', 'osi', 'boundary' or 'none'; any other
            value raises ValueError.

        GAMA_loss (bool): when True, the softmax output on the unperturbed input is
            passed to the loss together with ``gama_lambda``.

        num_restarts (int): number of restarts; defaults to the "pgd_restarts" config.

        pgd_steps (int): number of iterations; defaults to the "pgd_steps" config.

        only_replicate_restarts (bool): when True, X is replicated along the restart
            axis only and ``X.shape[1]`` is used as the second extra axis.

        return_early_stopped (bool): when True, the early-stop flag is appended to
            the returned tuple.

    Returns:
        ``(best_delta, delta, best_loss)``, plus ``early_stopped`` when
        ``return_early_stopped`` is True.
    '''
    device = X.device
    attack_iters = arguments.Config["attack"]["pgd_steps"] if pgd_steps is None else pgd_steps
    num_restarts = arguments.Config["attack"]["pgd_restarts"] if num_restarts is None else num_restarts

    lr_decay=arguments.Config["attack"]["pgd_lr_decay"]
    early_stop=arguments.Config["attack"]["pgd_early_stop"]

    if only_replicate_restarts:
        # X already carries an extra axis at position 1; drop it from the per-example shape.
        input_shape = (X.shape[0], *X.shape[2:])
    else:
        input_shape = X.size()
    num_classes = C_mat.shape[-1]

    num_or_spec = len(cond_mat[0])

    extra_dim = (num_restarts, num_or_spec) if only_replicate_restarts == False else (num_restarts,)
    # shape of x: [num_example, *shape_of_x]

    # Best loss per example so far, and the perturbation that achieved it.
    best_loss = torch.empty(X.size(0), device=device).fill_(float("-inf"))
    best_delta = torch.zeros(input_shape, device=device)

    data_min = data_min.unsqueeze(1)
    data_max = data_max.unsqueeze(1)
    # [1, 1, num_spec, *input_shape]

    X_ndim = X.ndim

    # Insert singleton axes so the bounds broadcast against X, then derive the
    # admissible range of the perturbation.
    X = X.view(X.shape[0], *[1] * len(extra_dim), *X.shape[1:])
    delta_lower_limit = data_min - X
    delta_upper_limit = data_max - X

    X = X.expand(-1, *extra_dim, *(-1,) * (X_ndim - 1))
    # From here on extra_dim always has two entries, read back from the expanded X.
    extra_dim = (X.shape[1], X.shape[2])

    if initialization == 'osi':
        # X_init = OSI_init(model, X, y, epsilon, alpha, num_classes, iter_steps=attack_iters, extra_dim=extra_dim, upper_limit=upper_limit, lower_limit=lower_limit)
        osi_start_time = time.time()
        X_init = OSI_init_C(model, X, alpha, C_mat.shape[-1], attack_iters, data_min, data_max)
        osi_time = time.time() - osi_start_time
        print(f'diversed PGD initialization time: {osi_time:.4f}')
    if initialization == 'boundary':
        # NOTE(review): boundary_attack is not defined in the visible part of this file;
        # presumably it is defined or imported elsewhere -- confirm.
        boundary_adv_examples = boundary_attack(model, X[:,0,...].view(-1, *input_shape[1:]), data_min.view(*input_shape), data_max.view(*input_shape))
        if boundary_adv_examples is not None:
            X_init = boundary_adv_examples.view(X.shape[0], -1, *X.shape[2:])
            # Keep only as many restarts as the boundary attack produced.
            X = X[:,:X_init.shape[1],...]
            extra_dim = (X.shape[1], X.shape[2])
        else:
            # Fall back to uniform initialization when no boundary examples are returned.
            initialization = 'uniform'

    gama_lambda = arguments.Config["attack"]["gama_lambda"]

    if initialization == 'osi' or initialization == 'boundary':
        delta = (X_init - X).detach().requires_grad_()
    elif initialization == 'uniform':
        # Uniform sample inside [delta_lower_limit, delta_upper_limit].
        delta = (torch.empty_like(X).uniform_() * (delta_upper_limit - delta_lower_limit) + delta_lower_limit).requires_grad_()
    elif initialization == 'none':
        delta = torch.zeros_like(X).requires_grad_()
    else:
        raise ValueError(f"Unknown initialization method {initialization}")

    if use_adam:
        # NOTE(review): AdamClipping is not defined in the visible part of this file.
        opt = AdamClipping(params=[delta], lr=alpha)
        scheduler = torch.optim.lr_scheduler.ExponentialLR(opt, lr_decay)

    early_stopped = False

    for _ in range(attack_iters):
        inputs = normalize(X + delta)
        output = model(inputs.view(-1, *input_shape[1:])).view(
            input_shape[0], *extra_dim, num_classes)

        if GAMA_loss:
            # Output on original model is needed if gama loss is used.
            origin_out = torch.softmax(model(normalize(X.reshape(-1, *input_shape[1:]))), 1)
            origin_out = origin_out.view(output.shape)
        else:
            origin_out = None

        # The loss function is looked up by evaluating a (trusted) config string.
        loss, loss_gama = eval(arguments.Config["attack"]["pgd_loss"])(
            origin_out, output, C_mat, rhs_mat,
            cond_mat, same_number_const,
            gama_lambda if GAMA_loss else 0.0,
            mode=arguments.Config['attack']['pgd_loss_mode'], model=model)
        gama_lambda *= arguments.Config["attack"]["gama_decay"]
        # shape of loss: [num_example, num_restarts, num_or_spec]
        # or float when gama_lambda > 0

        loss_gama.sum().backward()

        with torch.no_grad():
            # Save the best loss so far.
            if same_number_const:
                # Reduce over the AND axis with the minimum.
                loss = loss.amin(-1)
                # loss has shape [num_example, num_restarts, num_or_spec].
                # margins = (runnerup - groundtruth).view(groundtruth.size(0), -1)
            else:
                # group_C[g, s] == 1 iff flat spec row s belongs to OR group g.
                group_C = torch.zeros(len(cond_mat[0]), C_mat.shape[1]).to(loss.device) # [num_or_spec, num_total_spec]
                x_index = []
                y_index = []
                index = 0
                for i, cond in enumerate(cond_mat[0]):
                    for _ in range(cond):
                        x_index.append(i)
                        y_index.append(index)
                        index += 1
                group_C[x_index, y_index] = 1.0

                # loss shape: [batch_size, num_restarts, num_total_spec]
                loss = group_C.matmul(loss.unsqueeze(-1)).squeeze(-1)
                # loss shape: [batch_size, num_restarts, num_or_spec]

            loss = loss.view(loss.shape[0], -1)
            # all_loss and indices have shape (batch, ),
            # and this is the best loss over all restarts and number of classes.
            all_loss, indices = loss.max(1)
            # delta has shape (batch, restarts, num_class-1, c, h, w).
            # For each batch element, we want to select from the best over
            # (restarts, num_classes-1) dimension.
            # delta_targeted has shape (batch, c, h, w).
            delta_targeted = delta.view(
                delta.size(0), -1, *input_shape[1:]
            ).gather(
                dim=1, index=indices.view(
                    -1,1,*(1,) * (len(input_shape) - 1)).expand(
                        -1,-1,*input_shape[1:])
            ).squeeze(1)

            # Ties (>=) overwrite the stored perturbation with the newer one.
            best_delta[all_loss >= best_loss] = delta_targeted[all_loss >= best_loss]
            best_loss = torch.max(best_loss, all_loss)

        if early_stop:
            # The stop condition is also looked up by evaluating a config string.
            if eval(arguments.Config["attack"]["early_stop_condition"])(inputs, output, C_mat, rhs_mat,
                    cond_mat, same_number_const, data_max, data_min, model, indices, num_or_spec).all():
                print("pgd early stop")
                early_stopped = True
                break

        if use_adam:
            # NOTE(review): the clipping/sign arguments are specific to AdamClipping;
            # presumably the update is clipped to the delta limits -- confirm in its definition.
            opt.step(clipping=True, lower_limit=delta_lower_limit,
                     upper_limit=delta_upper_limit, sign=1)
            opt.zero_grad(set_to_none=True)
            scheduler.step()
        else:
            # Signed-gradient ascent step followed by projection onto the delta limits.
            d = delta + alpha * torch.sign(delta.grad)
            d = torch.max(torch.min(d, delta_upper_limit), delta_lower_limit)
            delta = d.detach().requires_grad_()

    if not early_stopped and 'Customized' in arguments.Config["attack"]["early_stop_condition"]:
        # With a customized stop condition, re-check the best perturbation against the
        # default conditions and discard the recorded loss if it does not pass.
        test_input = X[:, 0, 0, :] + best_delta
        test_output = model(test_input)
        test_input = test_input.unsqueeze(0).unsqueeze(0)
        test_output = test_output.unsqueeze(0).unsqueeze(0)
        if not test_conditions(test_input, test_output, C_mat, rhs_mat, cond_mat,
                           same_number_const, data_max, data_min).all():
            best_loss = torch.full(size=(1,), fill_value=float('-inf'), device=best_loss.device)

    if return_early_stopped:
        return best_delta, delta, best_loss, early_stopped
    else:
        return best_delta, delta, best_loss


def attack_pgd(model, X, y, epsilon, alpha, attack_iters, num_restarts,
        multi_targeted=True, num_classes=10, use_adam=True, lr_decay=0.98,
        lower_limit=0.0, upper_limit=1.0, normalize=lambda x: x, early_stop=True, target=None,
        initialization='uniform', GAMA_loss=False, nn4sys=False):
    r"""Run a PGD attack and return the best perturbation found.

    In the multi-targeted, targeted and nn4sys modes all random restarts are
    folded into extra tensor dimensions and processed in one batch, so the
    outer restart loop runs only once. In the plain non-targeted mode the
    outer loop runs ``num_restarts`` times.

    Args:
        model: Callable mapping a batch of inputs to logits.
        X (torch.Tensor): Clean inputs; the first dimension is the batch.
        y (torch.Tensor): Labels, used to index logits. In nn4sys mode it is
            instead multiplied with the output to form the loss.
        epsilon: Bound on the perturbation; delta is kept in
            ``[-epsilon, epsilon]`` intersected with
            ``[lower_limit - X, upper_limit - X]``.
        alpha: Learning rate of ``AdamClipping`` (``use_adam=True``) or step
            size of the sign-gradient update.
        attack_iters (int): Number of optimization steps per restart.
        num_restarts (int): Number of random restarts.
        multi_targeted (bool): Attack all ``num_classes - 1`` other labels at
            once. Requires ``target is None``.
        num_classes (int): Size of the last dimension of the model output.
        use_adam (bool): Use ``AdamClipping`` instead of plain sign steps.
        lr_decay (float): Gamma of the exponential learning-rate schedule
            (Adam only).
        lower_limit, upper_limit: Valid input range, either scalars or
            tensors. Tensors with the same rank as ``X`` are given extra
            restart/target dimensions so they broadcast against delta.
        normalize: Callable applied to ``X + delta`` before the model.
        early_stop (bool): Stop as soon as the attack succeeds for the batch.
        target: Target label(s) for a targeted attack, or None.
        initialization (str): ``'uniform'``, ``'osi'`` or ``'none'``.
        GAMA_loss (bool): Add the GAMA regularization term (multi-targeted
            mode only).
        nn4sys (bool): Use the nn4sys loss ``(-y * output).sum(-1)``.

    Returns:
        tuple: ``(best_delta, delta)``. ``best_delta`` has the shape of the
        original ``X`` and holds, per batch element, the perturbation with the
        largest loss seen. ``delta`` is the last iterate, including any extra
        restart/target dimensions.

    Raises:
        ValueError: If ``initialization`` is not a known method.
    """
    if initialization == 'osi':
        if multi_targeted:
            extra_dim = (num_restarts, num_classes - 1,)
        else:
            # NOTE(review): this is a plain int, not a 1-tuple; confirm what
            # OSI_init expects before changing it.
            extra_dim = (num_restarts)
        X_init = OSI_init(model, X, y, epsilon, alpha, num_classes,
                          iter_steps=attack_iters, extra_dim=extra_dim,
                          upper_limit=upper_limit, lower_limit=lower_limit)

    best_loss = torch.empty(X.size(0), device=X.device).fill_(float("-inf"))
    best_delta = torch.zeros_like(X, device=X.device)

    input_shape = X.size()

    def _add_restart_dims(limit, num_extra):
        # Element-wise limits with the same rank as X get ``num_extra`` new
        # dimensions after the batch dimension. Scalar limits (such as the
        # float defaults) have no ``ndim`` and broadcast unchanged.
        if getattr(limit, 'ndim', None) == len(input_shape):
            for _ in range(num_extra):
                limit = limit.unsqueeze(1)
        return limit

    if multi_targeted:
        assert target is None  # Multi-targeted attack is for non-targeted attack only.
        extra_dim = (num_restarts, num_classes - 1,)
        # Add two extra dimensions for targets. Shape is (batch, restarts, target, ...).
        X = X.unsqueeze(1).unsqueeze(1).expand(-1, *extra_dim, *(-1,) * (X.ndim - 1))
        # Generate target label list for each example.
        E = torch.eye(num_classes, dtype=X.dtype, device=X.device)
        c = E.unsqueeze(0) - E[y].unsqueeze(1)
        # remove specifications to self.
        I = ~(y.unsqueeze(1) == torch.arange(num_classes, device=y.device).unsqueeze(0))
        # c has shape (batch, num_classes - 1, num_classes).
        c = c[I].view(input_shape[0], num_classes - 1, num_classes)
        # c has shape (batch, restarts, num_classes - 1, num_classes).
        c = c.unsqueeze(1).expand(-1, num_restarts, -1, -1)
        target_y = y.view(-1,*(1,) * len(extra_dim),1).expand(-1, *extra_dim, 1)
        # Restart is processed in a batch and no need to do individual restarts.
        num_restarts = 1
        # If element-wise lower and upper limits are given, we should reshape them to the same as X.
        lower_limit = _add_restart_dims(lower_limit, 2)
        upper_limit = _add_restart_dims(upper_limit, 2)
    else:
        if target is not None:
            # An attack target for targeted attack, in dimension (batch, ).
            # Create it on the same device as the input rather than assuming CUDA.
            target = torch.tensor(target, device=X.device).view(-1,1)
            target_index = target.view(-1,1,1).expand(-1, num_restarts, 1)
            # Add an extra dimension for num_restarts. Shape is (batch, num_restarts, ...).
            X = X.unsqueeze(1).expand(-1, num_restarts, *(-1,) * (X.ndim - 1))
            # Only run 1 restart, since we run all restarts together.
            extra_dim = (num_restarts, )
            num_restarts = 1
            # If element-wise lower and upper limits are given, we should reshape them to the same as X.
            lower_limit = _add_restart_dims(lower_limit, 1)
            upper_limit = _add_restart_dims(upper_limit, 1)
        elif nn4sys: # nn4sys
            # Add an extra dimension for num_restarts. Shape is (batch, num_restarts, ...).
            X = X.unsqueeze(1).expand(-1, num_restarts, *(-1,) * (X.ndim - 1))
            extra_dim = (num_restarts, )
            num_restarts = 1
            # If element-wise lower and upper limits are given, we should reshape them to the same as X.
            lower_limit = _add_restart_dims(lower_limit, 1)
            upper_limit = _add_restart_dims(upper_limit, 1)

    # This is the maximal/minimal delta values for each sample, each element.
    sample_lower_limit = torch.clamp(lower_limit - X, min=-epsilon)
    sample_upper_limit = torch.clamp(upper_limit - X, max=epsilon)

    success = False

    for _ in range(num_restarts):
        one_label_loss = False  # Temporarily disabled. Will do more tests.
        gama_lambda = arguments.Config["attack"]["gama_lambda"]
        if early_stop and success:
            break

        if initialization == 'osi':
            delta = (X_init - X).detach().requires_grad_()
        elif initialization == 'uniform':
            delta = (torch.empty_like(X).uniform_() * (sample_upper_limit - sample_lower_limit) + sample_lower_limit).requires_grad_()
        elif initialization == 'none':
            delta = torch.zeros_like(X).requires_grad_()
        else:
            raise ValueError(f"Unknown initialization method {initialization}")

        if use_adam:
            opt = AdamClipping(params=[delta], lr=alpha)
            scheduler = torch.optim.lr_scheduler.ExponentialLR(opt, lr_decay)

        for _ in range(attack_iters):
            inputs = normalize(X + delta)
            if multi_targeted or target is not None:
                # Fold the extra restart/target dimensions into the batch for
                # the forward pass, then unfold them again.
                output = model(inputs.view(-1, *input_shape[1:])).view(input_shape[0], *extra_dim, num_classes)
            else:
                output = model(inputs)

            if not multi_targeted:
                # Not using the multi-targeted loss. Can be target or non-target attack.
                if nn4sys:
                    # [batch, num_restarts, num_class] -> [batch, num_restarts]
                    loss = (-y*output).sum(-1)
                elif target is not None:
                    # Targeted attack. In this case we have an extra (num_starts, ) dimension.
                    runnerup = output.scatter(dim=2, index=target_index, value=-float("inf")).max(2).values
                    # NOTE(review): indexing with the (batch, 1) ``target``
                    # tensor appears to reduce to (batch, restarts) only when
                    # the batch size is 1 - confirm before using larger batches.
                    t = output[:, :, target].squeeze(-1).squeeze(-1)
                    loss = (t - runnerup)
                else:
                    # Non-targeted attack.
                    if one_label_loss:
                        # Loss 1: simply reduce the loss groundtruth target.
                        loss = -output.gather(dim=1, index=y.view(-1,1)).squeeze(1)  # -groundtruth
                    else:
                        # Loss 2: reduce the margin between groundtruth and runner up label.
                        runnerup = output.scatter(dim=1, index=y.view(-1,1), value=-100.0).max(1).values
                        groundtruth = output.gather(dim=1, index=y.view(-1,1)).squeeze(1)
                        # Use the margin as the loss function.
                        loss = (runnerup - groundtruth)
                loss.sum().backward()
            else:
                if GAMA_loss:
                    origin_out = torch.softmax(model(normalize(X.reshape(-1, *input_shape[1:]))), 1)
                    origin_out = origin_out.view(output.shape)

                    # [batch, restarts, label - 1, logits]
                    out = torch.softmax(output, -1)

                    # Margin loss plus the GAMA term: squared distance between
                    # the softmax outputs of the perturbed and clean inputs.
                    loss = torch.einsum('ijkl,ijkl->', c, output)
                    loss = loss.sum() + (gama_lambda * (out - origin_out)**2).sum(dim=3).sum()
                    gama_lambda *= arguments.Config["attack"]["gama_decay"]
                else:
                    # Non-targeted attack, using margins between groundtruth class and all target classes together.
                    loss = torch.einsum('ijkl,ijkl->', c, output)
                loss.backward()

            with torch.no_grad():
                # Save the best loss so far.
                if not multi_targeted:
                    # Not using multi-targeted loss.
                    if nn4sys:
                        all_loss, indices = loss.max(1)
                        delta_best = delta.gather(dim=1, index=indices.view(-1,1,*[1]*len(input_shape[1:])).expand(-1,-1,*input_shape[1:])).squeeze(1)
                        best_delta[all_loss >= best_loss] = delta_best[all_loss >= best_loss]
                        best_loss = torch.max(best_loss, all_loss)
                    elif target is not None:
                        # Targeted attack, need to check if the top-1 label is target label.
                        # Since we merged the random restart dimension, we need to find the best one among all random restarts.
                        all_loss, indices = loss.max(1)
                        # Gather the delta for the best loss in all random restarts.
                        # The index is built for any input rank, not only (batch, c, h, w).
                        delta_best = delta.gather(dim=1, index=indices.view(-1, 1, *(1,) * (len(input_shape) - 1)).expand(-1,-1,*input_shape[1:])).squeeze(1)
                        best_delta[all_loss >= best_loss] = delta_best[all_loss >= best_loss]
                        best_loss = torch.max(best_loss, all_loss)
                    else:
                        # Non-targeted attack. Success when the groundtruth is not top-1.
                        if one_label_loss:
                            runnerup = output.scatter(dim=1, index=y.view(-1,1), value=-100.0).max(1).values
                            groundtruth = output.gather(dim=1, index=y.view(-1,1)).squeeze(1)
                            # Use the margin as the loss function.
                            criterion = (runnerup - groundtruth)  # larger is better.
                        else:
                            criterion = loss
                        # Larger is better.
                        best_delta[criterion >= best_loss] = delta[criterion >= best_loss]
                        best_loss = torch.max(best_loss, criterion)
                else:
                    # Using multi-targeted loss. Need to find which label causes the worst case margin.
                    # Keep the one with largest margin.
                    # Note that we recompute the runnerup label here - the runnerup label might not be the target label.
                    # output has shape (batch, restarts, num_classes-1, num_classes).
                    # runnerup has shape (batch, restarts, num_classes-1).
                    runnerup = output.scatter(dim=3, index=target_y, value=-float("inf")).max(3).values
                    # groundtruth has shape (batch, restarts, num_classes-1).
                    groundtruth = output.gather(dim=3, index=target_y).squeeze(-1)
                    # margins has shape (batch, restarts * (num_classes-1)).
                    margins = (runnerup - groundtruth).view(groundtruth.size(0), -1)
                    # all_loss and indices have shape (batch, ), and this is the best loss over all restarts and number of classes.
                    all_loss, indices = margins.max(1)
                    # delta has shape (batch, restarts, num_class-1, c, h, w). For each batch element, we want to select from the best over (restarts, num_classes-1) dimension.
                    # delta_targeted has shape (batch, c, h, w).
                    delta_targeted = delta.view(delta.size(0), -1, *input_shape[1:]).gather(dim=1, index=indices.view(-1,1,*(1,) * (len(input_shape) - 1)).expand(-1,-1,*input_shape[1:])).squeeze(1)
                    best_delta[all_loss >= best_loss] = delta_targeted[all_loss >= best_loss]
                    best_loss = torch.max(best_loss, all_loss)

                if early_stop:
                    if multi_targeted:
                        # Must be a untargeted attack. If any of the target succeed, that element in batch is successfully attacked.
                        # output has shape (batch, num_restarts, num_classes-1, num_classes,).
                        if (output.view(output.size(0), -1, num_classes).max(2).indices != y.unsqueeze(1)).any(1).all():
                            print('pgd early stop.')
                            success = True
                            break
                    elif target is not None:
                        # Targeted attack, the top-1 label of every element in batch must match the target.
                        # If any attack in some random restarts succeeds, the attack is successful.
                        if (output.max(2).indices == target).any(1).all():
                            print('pgd early stop.')
                            success = True
                            break
                    elif nn4sys:
                        # NOTE(review): ``target`` is always None on this
                        # branch (the targeted case is handled above), so this
                        # comparison looks suspect - confirm the intended
                        # threshold.
                        if (output.view(output.size(0), -1, num_classes) * y <= target).any():
                            print('pgd early stop.')
                            success = True
                            break
                    else:
                        # Non-targeted attack, the top-1 label of every element in batch must not be the groundtruth label.
                        if (output.max(1).indices != y).all():
                            print('pgd early stop.')
                            success = True
                            break

                # Optimizer step.
                if use_adam:
                    opt.step(clipping=True, lower_limit=sample_lower_limit, upper_limit=sample_upper_limit, sign=1)
                    opt.zero_grad(set_to_none=True)
                    scheduler.step()
                else:
                    # Plain sign-gradient ascent, projected back into the box.
                    d = delta + alpha * torch.sign(delta.grad)
                    d = torch.max(torch.min(d, sample_upper_limit), sample_lower_limit)
                    delta.copy_(d)
                    delta.grad = None

    return best_delta, delta


class AdamClipping(Optimizer):
    r"""Adam optimizer with an optional clipped, sign-based update.

    The moment estimates follow the Adam algorithm from PyTorch (see
    `Adam: A Method for Stochastic Optimization`_). When ``step`` is called
    with ``clipping=True`` the parameter is moved by a fixed amount along the
    sign of the Adam update and then projected into the element-wise box
    ``[lower_limit, upper_limit]``. Without clipping, the step is the
    ordinary Adam update.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False)

    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, amsgrad=False):
        # Validate hyper-parameters before handing them to the base class.
        if not (lr >= 0.0):
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not (eps >= 0.0):
            raise ValueError("Invalid epsilon value: {}".format(eps))
        first_beta = betas[0]
        if not (first_beta >= 0.0 and first_beta < 1.0):
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        second_beta = betas[1]
        if not (second_beta >= 0.0 and second_beta < 1.0):
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not (weight_decay >= 0.0):
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        super().__init__(params, {
            'lr': lr,
            'betas': betas,
            'eps': eps,
            'weight_decay': weight_decay,
            'amsgrad': amsgrad,
        })

    def __setstate__(self, state):
        super().__setstate__(state)
        # State restored without an 'amsgrad' entry falls back to plain Adam.
        for param_group in self.param_groups:
            param_group.setdefault('amsgrad', False)


    @staticmethod
    @torch.no_grad()
    @torch.jit.script
    def _clip_update(exp_avg : torch.Tensor, denom : torch.Tensor, step_size : float, clipping_step_eps : float, lower_limit : torch.Tensor, upper_limit : torch.Tensor, p : torch.Tensor):
        # Only the sign of the Adam update is used (Linf-style step); the
        # magnitude of the move is clipping_step_eps for every element.
        direction = torch.sign(exp_avg / denom * step_size)
        candidate = p.data + direction * clipping_step_eps
        # Project back into [lower_limit, upper_limit] and write in place.
        p.copy_(torch.max(torch.min(candidate, upper_limit), lower_limit))

    @torch.no_grad()
    def step(self, clipping=None, lower_limit=None, upper_limit=None, sign=None, closure=None):
        """Performs a single optimization step.

        Arguments:
            clipping (bool, optional): If truthy, apply the sign-based update
                projected into ``[lower_limit, upper_limit]``; otherwise apply
                the original Adam update.
            lower_limit (torch.Tensor, optional): Element-wise lower bound for
                the parameter (used only when clipping).
            upper_limit (torch.Tensor, optional): Element-wise upper bound for
                the parameter (used only when clipping).
            sign (int, optional): Must be 1 when clipping (gradient ascent).
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or None if no closure is given.
        """
        loss = closure_result = None
        if closure is not None:
            with torch.enable_grad():
                closure_result = closure()
        loss = closure_result

        # Currently we only deal with 1 parameter group.
        assert len(self.param_groups) == 1
        for group in self.param_groups:
            use_amsgrad = group['amsgrad']
            beta1, beta2 = group['betas']
            for param in group['params']:
                grad = param.grad
                if grad is None:
                    continue
                if grad.is_sparse:
                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')

                state = self.state[param]

                # Lazily create the optimizer state on the first step.
                if len(state) == 0:
                    state['step'] = 0
                    # Running average of the gradient.
                    state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format)
                    # Running average of the squared gradient.
                    state['exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format)
                    if use_amsgrad:
                        # Element-wise maximum of the squared-gradient average.
                        state['max_exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format)

                first_moment = state['exp_avg']
                second_moment = state['exp_avg_sq']

                state['step'] += 1
                steps_taken = state['step']
                bias_correction1 = 1 - beta1 ** steps_taken
                bias_correction2 = 1 - beta2 ** steps_taken

                if group['weight_decay'] != 0:
                    grad = grad.add(param, alpha=group['weight_decay'])

                # Update the biased first and second moment estimates in place.
                first_moment.mul_(beta1).add_(grad, alpha=1 - beta1)
                second_moment.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                normalizer = second_moment
                if use_amsgrad:
                    # AMSGrad normalizes by the running maximum instead.
                    normalizer = state['max_exp_avg_sq']
                    torch.max(normalizer, second_moment, out=normalizer)
                denom = (normalizer.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])

                step_size = group['lr'] / bias_correction1

                if not clipping:
                    # No clipping. Original Adam update.
                    param.addcdiv_(first_moment, denom, value=-step_size)
                    continue

                assert sign == 1  # gradient ascent for adversarial attacks.
                self._clip_update(first_moment, denom, step_size, step_size, lower_limit, upper_limit, param)

        return loss

def boundary_attack(model, x, data_min, data_max):
    """Enumerate the corners of the input box over the perturbed elements.

    An element is perturbed when ``data_max != data_min`` at that position.
    With ``k`` perturbed elements the result contains ``2 ** k`` candidates
    per starting row, each taking either the lower or the upper bound at
    every perturbed position.

    Bound values are read from row 0 of ``data_min``/``data_max`` only, so
    the function presumably expects a batch size of 1 - TODO confirm against
    callers.

    Args:
        model: Unused.
        x: Unused.
        data_min (torch.Tensor): Lower bounds, batch first.
        data_max (torch.Tensor): Upper bounds, same shape as ``data_min``.
            It is not modified.

    Returns:
        torch.Tensor or None: Candidates of shape ``(-1, *data_max.shape[1:])``,
        or None when more than 5 positions are perturbed.
    """
    batch = data_max.shape[0]
    # Each row of the nonzero() result is a (batch_index, flat_index) pair.
    perturbation_index = ((data_max - data_min) != 0).view(batch, -1).nonzero()
    # index of the pixels perturbed
    if len(perturbation_index) > 5:
        print("Error: number of perturbed pixels is larger than 5, boundary attack is disabled.")
        return None

    data_max_flatten = data_max.view(batch, -1)
    data_min_flatten = data_min.view(batch, -1)

    # Work on a copy so the in-place writes below never reach the caller's
    # data_max through the view.
    adv_example = data_max_flatten.clone()
    # Use only the flat element index of each pair. Indexing with the whole
    # (batch_index, flat_index) pair would also overwrite column
    # ``batch_index`` and collapse corners when element 0 is perturbed.
    for col in perturbation_index[:, 1]:
        adv_example_neg = adv_example.clone()
        adv_example[:, col] = data_max_flatten[0, col]
        adv_example_neg[:, col] = data_min_flatten[0, col]
        # Doubles the candidate set: upper-bound copies, then lower-bound copies.
        adv_example = torch.cat([adv_example, adv_example_neg], dim=0)

    return adv_example.view(-1, *data_max.shape[1:])




def attack_with_general_specs(model, x, data_min, data_max,
                              list_target_label_arrays,
                              initialization="uniform", GAMA_loss=False):
    r""" Interface to PGD attack.

    Gradients of all model parameters are disabled while the attack runs and
    restored to their previous state on every exit path, including the early
    return and exceptions.

    Args:
        model (torch.nn.Module): PyTorch module under attack.

        x (torch.tensor): Input image (x_0).
        [batch_size, *x_shape]

        data_min (torch.tensor): Lower bounds of data input. (e.g., 0 for mnist)
        shape: [batch_size, spec_num, *input_shape]

        data_max (torch.tensor): Upper bounds of data input. (e.g., 1 for mnist)
        shape: [batch_size, spec_num, *input_shape]

        list_target_label_arrays: a list of list of tuples:
                We have N examples, and list_target_label_arrays is a list containing N lists.
                Each inner list contains the target_label_array for an example:
                    [(prop_mat_1, prop_rhs_1), (prop_mat_2, prop_rhs_2), ..., (prop_mat_n, prop_rhs_n)]
                    prop_mat is a numpy array with shape [num_and, num_output], prop_rhs is a numpy array with shape [num_and]

        initialization (string): initialization of PGD attack, chosen from 'uniform' and 'osi'

        GAMA_loss (boolean): whether to use GAMA (Guided adversarial attack) loss in PGD attack

    Returns:
        tuple: ``(success, attack_image, attack_margin, all_adv_candidates)``.
        ``success`` is True when the attack conditions hold for the returned
        image (or already hold for the clean input). ``all_adv_candidates``
        is None when the attack was skipped because the clean prediction
        already satisfies the conditions.
    """
    attack_start_time = time.time()
    assert arguments.Config["specification"]["norm"] == np.inf, 'We only support Linf-norm attack.'
    use_adam = True

    device = x.device

    alpha = arguments.Config["attack"]["pgd_alpha"]
    alpha_scale = arguments.Config["attack"]["pgd_alpha_scale"]
    if alpha_scale:
        # Per-element step size proportional to the width of the input box;
        # this mode uses plain sign-gradient steps instead of Adam.
        alpha = (data_max - data_min) * float(alpha)
        use_adam = False
    else:
        if alpha == 'auto':
            max_eps = torch.max(data_max - data_min).item()/2
            alpha = max_eps / 4
        else:
            alpha = float(alpha)

    print(f'Attack parameters: initialization={initialization}, steps={arguments.Config["attack"]["pgd_steps"]}, restarts={arguments.Config["attack"]["pgd_restarts"]}, alpha={alpha}, initialization={initialization}, GAMA={GAMA_loss}')

    # Set all parameters without gradient, this can speedup things significantly.
    grad_status = {}
    for p in model.parameters():
        grad_status[p] = p.requires_grad
        p.requires_grad_(False)

    try:
        output = model(x).detach()

        # FIXME conflict with clean prediction
        # if arguments.Config['general']['save_output']:
        #     arguments.Globals['out']['pred'] = output.cpu()

        print('Model output of first 5 examples:\n', output[:5])

        C_mat, rhs_mat, cond_mat, same_number_const = build_conditions(x, list_target_label_arrays)

        output = output.unsqueeze(1).unsqueeze(1).repeat(1, 1, len(cond_mat[0]), 1)

        # NOTE(review): eval() on a configuration string executes arbitrary
        # code; the value is assumed to come from a trusted config file.
        finalizer = eval(arguments.Config["attack"]["adv_example_finalizer"])

        if test_conditions(x, output, C_mat, rhs_mat, cond_mat, same_number_const,
                           data_max.unsqueeze(1), data_min.unsqueeze(1)).all():
            print("Clean prediction incorrect, attack skipped.")
            # Obtain attack margin.
            attack_image, _, attack_margin = finalizer(model, x, torch.zeros_like(x), data_max, data_min, C_mat, rhs_mat, cond_mat)
            return True, attack_image.detach(), attack_margin.detach(), None

        data_min = data_min.to(device)
        data_max = data_max.to(device)
        rhs_mat = rhs_mat.to(device)
        C_mat = C_mat.to(device)
        num_restarts = arguments.Config["attack"]["pgd_restarts"]
        batch_size = arguments.Config["attack"]["pgd_batch_size"]
        best_deltas = None
        best_loss = None
        # Run the restarts in chunks of at most ``batch_size`` and keep, per
        # example, the perturbation with the largest loss across chunks.
        for _ in tqdm(range((num_restarts + batch_size - 1) // batch_size)):
            best_deltas_, last_deltas, best_loss_, early_stopped = pgd_attack_with_general_specs(
                model, x, data_min, data_max, C_mat, rhs_mat, cond_mat, same_number_const, alpha,
                initialization=initialization, GAMA_loss=GAMA_loss,
                use_adam=use_adam, num_restarts=min(batch_size, num_restarts), return_early_stopped=True)
            num_restarts -= batch_size
            if best_deltas is None:
                best_deltas = best_deltas_
                best_loss = best_loss_
            else:
                best_deltas[best_loss_ >= best_loss] = best_deltas_[
                    best_loss_ >= best_loss]
                best_loss = torch.max(best_loss, best_loss_)
            if early_stopped:
                break

        attack_image, attack_output, attack_margin = finalizer(
            model, x, best_deltas, data_max, data_min, C_mat, rhs_mat, cond_mat)

        # Adversarial images/candidates in all restarts and targets. Useful for BaB-attack.
        # last_deltas has shape [batch, num_restarts, specs, c, h, w]. Need the extra num_restarts and specs dim.
        # x has shape [batch, c, h, w] and data_min/data_max has shape [batch, num_specs, c, h, w].
        all_adv_candidates = torch.max(
                torch.min(x.unsqueeze(1).unsqueeze(1) + last_deltas,
                    data_max.unsqueeze(1)), data_min.unsqueeze(1))
    finally:
        # Go back to original requires_grad status, whichever way we leave.
        for p, status in grad_status.items():
            p.requires_grad_(status)

    attack_time = time.time() - attack_start_time
    print(f'Attack finished in {attack_time:.4f} seconds.')
    if test_conditions(attack_image.unsqueeze(1), attack_output.unsqueeze(1),
                       C_mat, rhs_mat, cond_mat, same_number_const,
                       data_max, data_min).all():
        print("PGD attack succeeded!")
        return True, attack_image.detach(), attack_margin.detach(), all_adv_candidates
    else:
        print("PGD attack failed")
        return False, attack_image.detach(), attack_margin.detach(), all_adv_candidates


def pgd_attack(dataset, model, x, max_eps, data_min, data_max, vnnlib=None, y=None,
               target=None, only_target_attack=False, initialization="uniform", GAMA_loss=False):
    # FIXME (01/11/2022): any parameter that can be read from config should not be passed in.
    r"""Interface to PGD attack.

    Args:
        dataset (str): The name of dataset. Each dataset might have different attack configurations.
            Supported values are "MNIST", "CIFAR", "UNKNOWN", and any name containing "NN4SYS".

        model (torch.nn.Module): PyTorch module under attack.

        x (torch.tensor): Input image (x_0). Overwritten by the center of each input box for NN4SYS.

        max_eps (float): Perturbation Epsilon. Assuming Linf perturbation for now. (e.g., 0.3 for MNIST)
            If it is not a float, its maximum element is used.

        data_min (torch.tensor): Lower bounds of data input. (e.g., 0 for mnist)

        data_max (torch.tensor): Upper bounds of data input. (e.g., 1 for mnist)

        vnnlib (list, optional): VNNLIB specifications. It will be used to extract attack target.

        y (int, optional): Groundtruth label. Used directly when `target` is None; otherwise the
            labels are extracted from `vnnlib` (or `target` is used as the attack target).

        target (int, optional): Target label, used when `vnnlib` is None.

        only_target_attack (bool): Run a targeted attack against the first target label instead of
            an untargeted (multi-targeted) attack.

        initialization (str): Forwarded to `attack_pgd`.

        GAMA_loss (bool): Forwarded to `attack_pgd`.

    Returns:
        success (bool): True if attack is successful. Otherwise False.
        attack_images (torch.tensor): last adversarial examples so far, may not be a real adversarial example if attack failed
        attack_margin: attack margins; a one-element list for the single-image targeted attack,
            otherwise a numpy array. For NN4SYS this is the model output on the attack images.

    Raises:
        NotImplementedError: for unsupported datasets, missing attack targets, a target equal to
            the groundtruth label, or batched untargeted attacks.
    """
    # The message must be a string; `print(...)` evaluates to None and loses the message.
    assert arguments.Config["specification"]["norm"] == np.inf, 'We only support Linf-norm attack.'
    if dataset in ["MNIST", "CIFAR", "UNKNOWN"]:  # FIXME (01/11/2022): Make the attack function generic, not for the two datasets only!
        # FIXME (01/11/2022): Generic specification PGD.
        if y is not None and target is None:
            # Use y as the groundtruth label.
            pidx_list = ["all"]
            if arguments.Config["specification"]["type"] == "lp":
                if data_max is None:
                    data_max = x + max_eps
                    data_min = x - max_eps
                else:
                    data_max = torch.min(x + max_eps, data_max)
                    data_min = torch.max(x - max_eps, data_min)
            # If arguments.Config["specification"]["type"] == "bound", then we keep data_min and data_max.
        else:
            pidx_list = []
            if vnnlib is not None:
                # Extract attack target from vnnlib.
                for prop_mat, prop_rhs in vnnlib[0][1]:
                    if len(prop_rhs) > 1:
                        # Several constraints in one property: attack the one with the largest
                        # value on the clean model output.
                        output = model(x).detach().cpu().numpy().flatten()
                        print('model output:', output)
                        vec = prop_mat.dot(output)
                        selected_prop = prop_mat[vec.argmax()]
                        y = int(np.where(selected_prop == 1)[0])  # true label, whatever in target attack
                        pidx = int(np.where(selected_prop == -1)[0])  # target label
                        only_target_attack = True
                    else:
                        assert len(prop_mat) == 1
                        y = np.where(prop_mat[0] == 1)[0]
                        if len(y) != 0:
                            y = int(y)
                        else:
                            y = None
                        pidx = int(np.where(prop_mat[0] == -1)[0])  # target label
                    if pidx == y:
                        raise NotImplementedError
                    pidx_list.append(pidx)
            elif target is not None:
                pidx_list.append(target)
            else:
                raise NotImplementedError

        print('##### PGD attack: True label: {}, Tested against: {} ######'.format(y, pidx_list))
        if not isinstance(max_eps, float):
            max_eps = torch.max(max_eps).item()

        if only_target_attack:
            # Targeted attack PGD.
            if arguments.Config["attack"]["pgd_alpha"] == 'auto':
                alpha = max_eps
            else:
                alpha = float(arguments.Config["attack"]["pgd_alpha"])
            best_deltas, last_deltas = attack_pgd(model, X=x, y=None, epsilon=float("inf"), alpha=alpha, num_classes=arguments.Config["data"]["num_outputs"],
                    attack_iters=arguments.Config["attack"]["pgd_steps"], num_restarts=arguments.Config["attack"]["pgd_restarts"], upper_limit=data_max, lower_limit=data_min,
                    multi_targeted=False, lr_decay=arguments.Config["attack"]["pgd_lr_decay"], target=pidx_list[0], initialization=initialization, early_stop=arguments.Config["attack"]["pgd_early_stop"], GAMA_loss=GAMA_loss)
        else:
            # Untargeted attack PGD.
            if arguments.Config["attack"]["pgd_alpha"] == 'auto':
                alpha = max_eps/4.0
            else:
                alpha = float(arguments.Config["attack"]["pgd_alpha"])
            best_deltas, last_deltas = attack_pgd(model, X=x, y=torch.tensor([y], device=x.device), epsilon=float("inf"), alpha=alpha, num_classes=arguments.Config["data"]["num_outputs"],
                    attack_iters=arguments.Config["attack"]["pgd_steps"], num_restarts=arguments.Config["attack"]["pgd_restarts"], upper_limit=data_max, lower_limit=data_min,
                    multi_targeted=True, lr_decay=arguments.Config["attack"]["pgd_lr_decay"], target=None, initialization=initialization, early_stop=arguments.Config["attack"]["pgd_early_stop"], GAMA_loss=GAMA_loss)

        if x.shape[0] == 1:
            attack_image = torch.max(torch.min(x + best_deltas, data_max), data_min)
            assert (attack_image >= data_min).all()
            assert (attack_image <= data_max).all()
            # assert (attack_image-x).abs().max() <= eps_temp.max(), f"{(attack_image-x).abs().max()} <= {eps_temp.max()}"
            attack_output = model(attack_image).squeeze(0)

            # FIXME (10/02): This is not the best image for each attack target. We should save best attack delta for each target.
            all_targets_attack_image = torch.max(torch.min(x + last_deltas, data_max), data_min)

            attack_label = attack_output.argmax()
            print("pgd prediction:", attack_output)

            # FIXME (10/05): Cleanup.
            if only_target_attack:
                # Targeted attack, must have one label.
                attack_logit = attack_output.data[pidx_list[0]].item()
                # Mask out the target logit so that max() picks the best other class.
                attack_output.data[pidx_list[0]] = -float("inf")
                attack_margin = attack_output.max().item() - attack_logit
                print("attack margin", attack_margin)
                if attack_label == pidx_list[0]:
                    assert len(pidx_list) == 1
                    print("targeted pgd succeed, label {}, against label {}".format(y, attack_label))
                    # FIXME (10/05): Please check! attack_image is for one target only.
                    return True, attack_image.detach(), [attack_margin]
                else:
                    print(f"targeted pgd failed, margin {attack_margin}")
                    return False, attack_image.detach(), [attack_margin]
            else:
                # Untargeted attack, any non-groundtruth label is ok.
                groundtruth_logit = attack_output.data[y].item()
                # Mask out the groundtruth logit; its margin entry becomes +inf.
                attack_output.data[y] = -float("inf")
                attack_margin = groundtruth_logit - attack_output
                print("attack margin", attack_margin)
                # Untargeted attack, any non-groundtruth label is ok.
                if attack_label != y:
                    print("untargeted pgd succeed, label {}, against label {}".format(y, attack_label))
                    return True, attack_image.detach(), attack_margin.detach().cpu().numpy()
                else:
                    print("untargeted pgd failed")
                    return False, attack_image.detach(), attack_margin.detach().cpu().numpy()
        else:
            attack_images = torch.max(torch.min(x + best_deltas, data_max), data_min)
            attack_output = model(attack_images).squeeze(0)
            # do in batch attack
            attack_label = attack_output.argmax(1)

            if only_target_attack:
                # Targeted attack, must have one label. The attack succeeds if any image in the
                # batch is classified as the target label.
                success = bool((attack_label == pidx_list[0]).any())
                if success:
                    assert len(pidx_list) == 1
                # The margin is computed the same way whether or not the attack succeeded.
                attack_logit = attack_output.data[:, pidx_list[0]].clone()
                attack_output.data[:, pidx_list[0]] = -float("inf")
                attack_margin = attack_output.max(1).values - attack_logit
                return success, attack_images.detach(), attack_margin.detach().cpu().numpy()
            else:
                # TODO support batch untargeted attack (any non-groundtruth label is ok).
                raise NotImplementedError
    elif "NN4SYS" in dataset:
        # attack for nn4sys, attack the specs line by line
        # vnnlib:
        # [num_spec, num_X, 2(lower, upper)], a y <= b, num_spec
        specs = [vnnlib[i][0] for i in range(len(vnnlib))]
        y_sign = [vnnlib[i][1][0][0] for i in range(len(vnnlib))]
        y_upper = [vnnlib[i][1][0][1] for i in range(len(vnnlib))]

        # attack the specs in batch
        for i, spec in enumerate(specs):
            # print the top 10 specs for debugging.
            print('##### PGD attack: Batch {}, Threshold: {} ######'.format(i, y_upper[i][:10].squeeze() * y_sign[i][:10].squeeze()), y_sign[i][:10].squeeze())
            data_min = torch.Tensor(specs[i][:,:,0]).view(-1, *x.shape[1:]).to(x.device)
            data_max = torch.Tensor(specs[i][:,:,1]).view(-1, *x.shape[1:]).to(x.device)

            # Start the attack from the center of each input box.
            x = (data_min + data_max)/2

            # Mean width over the input dimensions that are actually perturbed.
            eps = ((data_max - data_min)[(data_max - data_min)!=0]).mean().item()
            y = torch.Tensor(y_sign[i]).to(x.device)
            target = torch.Tensor(y_upper[i]).to(x.device)

            if arguments.Config["attack"]["pgd_alpha"] == 'auto':
                alpha = eps/4.0
            else:
                alpha = float(arguments.Config["attack"]["pgd_alpha"])

            best_deltas, last_deltas = attack_pgd(model, X=x, y=y, target=target, epsilon=float("inf"), alpha=alpha, num_classes=arguments.Config["data"]["num_outputs"],
                attack_iters=arguments.Config["attack"]["pgd_steps"], num_restarts=arguments.Config["attack"]["pgd_restarts"], upper_limit=data_max, lower_limit=data_min,
                multi_targeted=False, lr_decay=arguments.Config["attack"]["pgd_lr_decay"], initialization=initialization, early_stop=arguments.Config["attack"]["pgd_early_stop"], GAMA_loss=GAMA_loss, nn4sys=True)


            attack_images = torch.max(torch.min(x + best_deltas, data_max), data_min)
            attack_output = model(attack_images).squeeze(0)
            print("model output: ", model(x).squeeze(0)[:10].squeeze())
            print("attack output: ", attack_output[:10].squeeze())

            # Entries where sign * output <= threshold count as successful attacks.
            success_mask = attack_output.detach().cpu().numpy().squeeze() * y_sign[i].squeeze() <= y_upper[i]
            if success_mask.any():
                index = success_mask.nonzero()
                # Index the thresholds of the current batch (y_upper is a list over batches),
                # and report the attack output as well.
                print("pgd succeed, against upper bound {} with {}".format(y_upper[i][index[0]], attack_output[index[0]]))
                return True, attack_images[index[0]].detach(), attack_output[index[0]].detach().cpu().numpy()

        print("pgd failed")
        return False, attack_images.detach(), attack_output.detach().cpu().numpy()
    else:
        print("pgd attack not supported for dataset", dataset)
        raise NotImplementedError


def auto_attack(model_ori, x, max_eps=None, data_max=None, data_min=None, y=None, vnnlib=None):
    """Run an untargeted Linf AutoAttack ('standard' version) against `model_ori`.

    AutoAttack takes a single scalar epsilon. When the perturbation is not a single
    float, the input is rescaled per channel so that one scalar epsilon applies, and
    the model is wrapped in a `Normalization` module that undoes the rescaling.

    Args:
        model_ori (torch.nn.Module): Model under attack.

        x (torch.tensor): Clean input. Assumed to be 4-D [batch, channels, H, W]
            (the normalization tensors are shaped [1, channels, 1, 1]) -- TODO confirm.

        max_eps (float or torch.tensor, optional): Linf perturbation. If None, half of
            `data_max - data_min` is used. Non-float values are indexed as `[:, :, 0, 0]`
            to get one epsilon per channel.

        data_max (torch.tensor): Upper bounds of the input; attack images are clipped to it.

        data_min (torch.tensor): Lower bounds of the input; attack images are clipped to it.

        y (int, optional): Groundtruth label. Overwritten when `vnnlib` is given.

        vnnlib (list, optional): VNNLIB specifications used to extract the groundtruth label.

    Returns:
        success (bool): True if the attack image is not classified as `y`.
        attack_images (torch.tensor): Detached attack images.
        attack_margin (numpy.ndarray): Groundtruth logit minus every logit; the entry of
            the groundtruth label itself is set to inf.

    Raises:
        NotImplementedError: if a property's target label equals the groundtruth label.
        AssertionError: if the specification requires a targeted attack.
    """
    if not isinstance(max_eps, float):
        # Per-channel epsilon: scale every channel so that all channels share the
        # same (mean) epsilon in the rescaled input space.
        eps = (data_max - data_min)/2 if max_eps is None else max_eps
        eps = eps[:,:,0,0]

        standard_eps = eps.mean()
        factor = standard_eps/eps
        std = factor
        mean = torch.zeros(std.shape).view([1, -1, 1, 1]).to(x.device)
        std = (std).view([1, -1, 1, 1]).to(x.device)
    else:
        # Scalar epsilon: identity normalization. `std` does not exist yet in this
        # branch, so the shape has to come from the channel dimension of x.
        standard_eps = max_eps
        mean = torch.zeros(1, x.shape[1], 1, 1, device=x.device)
        std = torch.ones(1, x.shape[1], 1, 1, device=x.device)
    unormalized_x = x * std + mean

    from autoattack import AutoAttack

    normalized = Normalization(mean, std, model_ori)
    normalized = normalized.to(x.device)
    # float() handles both a Python float and a 0-d tensor.
    adversary = AutoAttack(normalized, norm='Linf', eps=float(standard_eps), version='standard')
    # adversary.attacks_to_run = ['apgd-ce']

    only_target_attack = False # untargeted attack by default

    if vnnlib is not None:
        # Extract attack target from vnnlib.
        for prop_mat, prop_rhs in vnnlib[0][1]:
            if len(prop_rhs) > 1:
                output = model_ori(x).detach().cpu().numpy().flatten()
                print('model output:', output)
                vec = prop_mat.dot(output)
                selected_prop = prop_mat[vec.argmax()]
                y = int(np.where(selected_prop == 1)[0])  # true label, whatever in target attack
                pidx = int(np.where(selected_prop == -1)[0])  # target label
                only_target_attack = True
            else:
                assert len(prop_mat) == 1
                y = np.where(prop_mat[0] == 1)[0]
                if len(y) != 0:
                    y = int(y)
                else:
                    y = None
                pidx = int(np.where(prop_mat[0] == -1)[0])  # target label
            if pidx == y:
                raise NotImplementedError
    assert not only_target_attack # untargeted attack by default

    attack_images = adversary.run_standard_evaluation(unormalized_x, torch.Tensor([y]).long().to(x.device), bs=1)
    # Map back to the original input space and clip to the valid input range.
    attack_images = torch.max(torch.min((attack_images - mean)/std, data_max), data_min)
    if max_eps is not None:
        attack_images = torch.max(torch.min(attack_images, x + max_eps), x - max_eps)
    attack_ret = model_ori(attack_images)
    attack_margin = attack_ret[:,y] - attack_ret[0]
    attack_margin[y] = float("inf")
    attack_label = attack_ret.argmax().item()
    success = attack_label != y

    attack_images = attack_images.detach()
    attack_margin = attack_margin.detach().cpu().numpy()

    if success:
        print("untargeted auto attack succeed, label {}, against label {}".format(y, attack_label))
    else:
        print("untargeted auto attack failed")

    return success, attack_images, attack_margin


def attack_after_crown(lb, vnnlib, model_ori, x, decision_thresh):
    """Attack the specifications that are left after comparing bounds to a threshold.

    Each entry of `lb[-1]` is compared against `decision_thresh`. Entries strictly
    above their threshold are flagged in `crown_filtered_constraints`, which is
    passed on to `attack`.

    Args:
        lb: Sequence of lower bounds; only the last element `lb[-1]` is read.
        vnnlib: VNNLIB specification, passed to `attack` wrapped in a list.
        model_ori: Model under attack.
        x: Input passed to `attack`.
        decision_thresh: Either a float, or a 2-D tensor with a leading batch
            dimension of 1 holding one threshold per entry of `lb[-1]`.

    Returns:
        (verified_success, attack_images): second and third results of `attack`.
    """
    final_lb = lb[-1]
    num_constraints = len(final_lb)
    # Float array; flagged entries are written as True (stored as 1.0).
    crown_filtered_constraints = np.zeros(num_constraints)
    thresh_is_tensor = isinstance(decision_thresh, torch.Tensor)

    for idx in range(num_constraints):
        if thresh_is_tensor:
            assert decision_thresh.ndim == 2
            assert decision_thresh.shape[0] == 1, "Batch size should be 1 for attack after CROWN."
        lower_bound = final_lb[idx].item()
        # Per-constraint threshold for tensors, a shared one for a float threshold.
        threshold = decision_thresh[0][idx].item() if thresh_is_tensor else decision_thresh
        if lower_bound > threshold:
            crown_filtered_constraints[idx] = True

    attack_results = attack(
        model_ori, x, [vnnlib],
        verified_status="unknown", verified_success=False,
        crown_filtered_constraints=crown_filtered_constraints)
    _, verified_success, attack_images, _, _ = attack_results

    return verified_success, attack_images

