import sys, argparse, logging
import random
import pickle
import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras.preprocessing.image import save_img
import numpy as np
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.models import load_model
from art.utils import load_mnist, load_cifar10
from tensorflow.keras.datasets import cifar100
from art.attacks.evasion import FastGradientMethod, ProjectedGradientDescent, CarliniL2Method, CarliniLInfMethod, DeepFool, SaliencyMapMethod, AutoProjectedGradientDescent, AutoAttack, SquareAttack
from art.estimators.classification import KerasClassifier, TensorFlowV2Classifier
from sklearn.metrics import accuracy_score
from scipy.stats import entropy
from scipy.special import softmax
from test_script_bayesian_settings_finder import find_MuStd, uncertaintyDiscriminator, MC_predictions, MC_Harvest
from Laplace_approximation_Diagonal2 import LaplaceApproximation
from conversion import insert_layer_nonseq
from mcdropout import MCDropout
from pipeline_test2 import *
#data_dir = "/local/Data/tiny-imagenet-200"
#from def_settings import *
# Global Variable section
# Pass 9 positional command-line arguments:
# dataset, model_path, attacks, NOTA_def, NOTA_attk, MC_dropout, LaPlace, LaPlace_weights_path, batch_size_runs

# Maximum memory (in MB) TensorFlow may allocate on the first GPU: 38912 MB = 38 GiB.
GPU_MEMORY_LIMIT_MB = 38912

gpus = tf.config.list_physical_devices('GPU')
if gpus:
    # Expose only the first GPU to TensorFlow and cap its memory at GPU_MEMORY_LIMIT_MB.
    try:
        tf.config.set_visible_devices(gpus[0], 'GPU')
        tf.config.set_logical_device_configuration(
            gpus[0],
            [tf.config.LogicalDeviceConfiguration(memory_limit=GPU_MEMORY_LIMIT_MB)])
        logical_gpus = tf.config.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
        # Visible/virtual device configuration must be set before the GPUs have been
        # initialized; TensorFlow raises RuntimeError otherwise.
        print('THERE WAS A BIG ERROR')
        print(e)



# Function section
# use only after attack constructor
def execute_attack(attack, X_adv_src, y_adv_src, stochastic, NOTA_attk:bool=False, n_mc:int=100, JSMA:bool=False):
    """Run an already-constructed attack on a batch of source examples.

    Args:
        attack: attack object exposing ``generate`` (and ``generate_stochastic`` for
            stochastic runs); both are called with keyword arguments only.
        X_adv_src: clean inputs to perturb, passed as ``x``.
        y_adv_src: labels for ``X_adv_src``, passed as ``y`` unless ``JSMA`` is set.
        stochastic: if True call ``generate_stochastic`` (with ``n_mc``), else ``generate``.
        NOTA_attk: forwarded as the attack's ``NOTA`` keyword.
        n_mc: Monte Carlo sample count, only forwarded when ``stochastic`` is True.
        JSMA: if True, ``y`` is not forwarded to the attack.

    Returns:
        Whatever the selected generate method returns (the adversarial inputs).
    """
    call_kwargs = {'x': X_adv_src}
    if not JSMA:
        call_kwargs['y'] = y_adv_src
    call_kwargs['NOTA'] = NOTA_attk

    if not stochastic:
        return attack.generate(**call_kwargs)

    call_kwargs['n_mc'] = n_mc
    return attack.generate_stochastic(**call_kwargs)

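# Runs the model (deterministically, or with Monte Carlo sampling for MC-dropout/LaPlace models) on adversarial
# examples and returns its predicted labels plus a secondary label; print_results() then scores whether the attack
# changed the model's output to a non-NOTA, non-true class.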
def evaluate_attack(classifier, X_adv, y_adv_src, num_class=11, n_mc=30, nota:bool=True, MC_dropout:bool=False,
                    LaPlace:bool=False):
    """Predict labels for adversarial examples with a deterministic or stochastic model.

    Args:
        classifier: ART-style classifier. The deterministic path calls
            ``classifier.predict(X_adv, training_mode=False)`` and applies a softmax
            over the last axis, i.e. treats the output as logits.
        X_adv: adversarial inputs.
        y_adv_src: true labels; only its shape is used (for placeholder secondary labels).
        num_class: number of model outputs, including the NOTA class when ``nota`` is set.
        n_mc: number of Monte Carlo samples for the stochastic path.
        nota: whether the last class (index ``num_class - 1``) is the NOTA class.
        MC_dropout, LaPlace: if either is True, predictions come from ``MC_predictions``.

    Returns:
        ``(predicted_labels, secondary_labels)``. On the deterministic path with ``nota``,
        the secondary label is the argmax after zeroing the NOTA probability (the best
        non-NOTA class); without ``nota`` it is zeros shaped like ``y_adv_src``. On the
        stochastic path it is the fourth value returned by ``MC_predictions``.
    """
    if MC_dropout or LaPlace:
        # NOTE(review): every example goes through MC_predictions as a single batch.
        preds, _, _, secondary = MC_predictions(classifier,
                                                X_adv,
                                                m=n_mc,
                                                num_class=num_class,
                                                nota=nota,
                                                sftmx=True,
                                                batch_size=len(X_adv))
        label = "Stochastic Model "
    else:
        probs = softmax(classifier.predict(X_adv, training_mode=False), axis=-1)
        preds = np.argmax(probs, axis=1)
        if nota:
            # Suppress the NOTA column so the next argmax gives the best non-NOTA class.
            probs[:, num_class - 1] = 0.0
            secondary = np.argmax(probs, axis=1)
        else:
            secondary = np.zeros(y_adv_src.shape)
        label = "Deterministic Model "
    print(label)
    return preds, secondary

def print_results(y_adv_pred, y_adv_src, y_sec, num_class:int=11, nota:bool=False, MC_dropout:bool=False,
                  LaPlace:bool=False, verbose:bool=True):
    """Score adversarial predictions, print a summary, and return the scores.

    An example counts toward attack success when its prediction differs from the
    true label and, if ``nota`` is set, the prediction is also not the NOTA class
    (index ``num_class - 1``). Robust accuracy is the fraction predicted correctly.

    Args:
        y_adv_pred: predicted labels, indexable like ``y_adv_src``.
        y_adv_src: true labels; must expose ``.shape[0]``.
        y_sec: secondary labels, printed (when ``verbose``) for examples predicted as NOTA.
        num_class: number of outputs including the NOTA class when ``nota`` is set.
        nota: whether the NOTA class exists and should not count as a successful attack.
        MC_dropout, LaPlace: only select the "Stoch_"/"Det_" label in the printout.
        verbose: print a line for every example.

    Returns:
        ``(robust_accuracy, attack_success_rate)`` as floats. Both are NaN for an empty
        batch instead of raising ZeroDivisionError.
    """
    type_ = "Stoch_" if (MC_dropout or LaPlace) else "Det_"
    nota_idx = num_class - 1
    n_examples = y_adv_src.shape[0]
    success = 0.0
    acc = 0.0
    for j in range(n_examples):
        pred = y_adv_pred[j]
        true = y_adv_src[j]
        if verbose:
            print("Actual: ", true, ' | ', type_, 'Prediction: ', pred)
        is_nota = nota and pred == nota_idx
        if pred != true:
            # Under a NOTA defence, being pushed into the NOTA class is a caught attack.
            if not is_nota:
                success += 1.0
                if verbose:
                    print("-->counted")
        else:
            acc += 1.0
        if is_nota and verbose:
            print("2nd Highest: ", str(y_sec[j]))

    if n_examples:
        att_succ = success / n_examples
        reg_accuracy = acc / n_examples
    else:
        att_succ = reg_accuracy = float('nan')
    print("Robust Class Accuracy: ", str(reg_accuracy))
    print(type_, "attack_success: ", str(att_succ))
    print("*******************************************************")

    return reg_accuracy, att_succ

def save_adversarials(det_adv_expls, stoch_adv_expls, y_true_labels, attk_list, wbx_test_model, mod_type, NOTA_attk):
    """Append deterministic and stochastic adversarial sets to two pickle files in ./0_data/.

    Each file receives one pickled record ``[examples, y_true_labels]``. Files are opened
    in append-binary mode, so repeated calls add further records rather than overwrite.

    Args:
        det_adv_expls: adversarial examples from the deterministic white-box attack.
        stoch_adv_expls: adversarial examples from the stochastic (EoT) attack.
        y_true_labels: true labels shared by both sets.
        attk_list: attack names (strings), joined with '-' into the file name.
        wbx_test_model: white-box model path; characters from index 15 on go into the file name.
        mod_type: model-type tag included in the file name.
        NOTA_attk: whether the attacks were NOTA-aware (included in the file name).
    """
    attk_string = "-".join(attk_list)
    # NOTE(review): other paths in this script slice the model path with [17:]; this uses [15:].
    # Kept unchanged to preserve existing file names - confirm which prefix length is intended.
    base = "./0_data/" + wbx_test_model[15:] + mod_type + "NOTAattk=" + str(NOTA_attk) + attk_string
    file2 = base + 'DetWB_linf2'
    file3 = base + 'EoT_linf2'
    # Context managers guarantee the handles are closed even if pickling fails.
    with open(file2, 'ab') as fp2:
        pickle.dump([det_adv_expls, y_true_labels], fp2, protocol=pickle.HIGHEST_PROTOCOL)
    with open(file3, 'ab') as fp3:
        pickle.dump([stoch_adv_expls, y_true_labels], fp3, protocol=pickle.HIGHEST_PROTOCOL)
    print("Deterministic Whitebox Examples saved here: ", file2)
    print("Stochastic Whitebox Expectation over Tranformation Examples saved here: ", file3)
    return

def attack_and_report(classifiers, det_attk, attack, X_advSrc, y_advSrc, NOTA_attk:bool=False, NOTA_def:bool=False, 
                      num_class:int=11, n_mc:int=30, MC_dropout:bool=False, LaPlace:bool=False, JSMA:bool=False, 
                      verbose:bool=True):
    """Run the deterministic-transfer attack and, for stochastic models, the stochastic attack.

    Both passes are evaluated against ``classifiers[0]`` and reported via ``print_results``.
    The deterministic pass always runs ``det_attk`` with ``stochastic=False``. The
    stochastic pass runs ``attack`` with ``stochastic=True``, and only when
    ``MC_dropout`` or ``LaPlace`` is set.

    Returns:
        ``(det_examples, stoch_examples)``: lists of the generated adversarial examples.
        ``stoch_examples`` stays empty for a purely deterministic model.
    """
    det_examples = []
    stoch_examples = []

    def _run_pass(attack_obj, use_stochastic, banner, sink):
        # Generate adversarials, score them on classifiers[0], then collect them into sink.
        adv = execute_attack(attack_obj, X_advSrc, y_advSrc, stochastic=use_stochastic, NOTA_attk=NOTA_attk,
                             n_mc=n_mc, JSMA=JSMA)
        preds, secondary = evaluate_attack(classifiers[0], adv, y_advSrc, num_class=num_class, n_mc=n_mc,
                                           nota=NOTA_def, MC_dropout=MC_dropout, LaPlace=LaPlace)
        print(banner)
        print_results(preds, y_advSrc, secondary, num_class=num_class, nota=NOTA_def, MC_dropout=MC_dropout,
                      LaPlace=LaPlace, verbose=verbose)
        sink.extend(adv)

    # Deterministic white-box transfer needs the attack built on the deterministic model.
    _run_pass(det_attk, False, "Deterministic WB Transfer", det_examples)
    if MC_dropout or LaPlace:
        _run_pass(attack, True, "Stochastic WB Attack using expectation of the gradient", stoch_examples)
    return det_examples, stoch_examples

    
def print_norm_dist(X_adv_src, X_adv):
    """Print per-example L_inf and L_2 distances between clean and adversarial inputs.

    Each example (axis 0) is flattened before taking its norm.
    """
    n_examples = X_adv_src.shape[0]
    delta = np.reshape(X_adv_src - X_adv, [n_examples, -1])
    linf = np.abs(delta).max(axis=1)
    l2 = np.sqrt(np.square(delta).sum(axis=1))
    print("L_inf Norm", linf)
    print("L_2 Norm", l2)

def allSet_eval(X_val, dataset, model, max_test: int = 500):
    """Return the argmax class prediction of ``model`` for every row of ``X_val``.

    For the 'TI' dataset the inputs are fed to ``model`` in chunks of at most
    ``max_test`` rows, including a final partial chunk, so the output always has one
    prediction per input row. Every other dataset is evaluated in a single call.

    Args:
        X_val: inputs with the batch along axis 0.
        dataset: dataset key ('cifar10', 'cifar100', 'TI', ...).
        model: callable mapping a batch of inputs to per-class scores shaped (batch, classes).
        max_test: chunk size used for 'TI' (must be positive). Defaults to 500.

    Returns:
        1-D array of predicted class indices, one per row of ``X_val``.
    """
    if dataset == 'TI':
        n_rows = X_val.shape[0]
        chunks = [np.asarray(model(X_val[start:start + max_test]))
                  for start in range(0, n_rows, max_test)]
        if not chunks:
            return np.zeros(0, dtype=np.intp)
        y_pred = np.concatenate(chunks, axis=0)
    else:
        y_pred = model(X_val)
    return np.argmax(y_pred, axis=1)

def mcdropout_layer_factory(p):
    """Build the layer that replaces standard dropout when converting to MC dropout.

    Args:
        p: dropout rate, passed to ``MCDropout`` as ``rate``.

    Returns:
        A new ``MCDropout`` layer named 'mcdropout'.
    """
    layer = MCDropout(name='mcdropout', rate=p)
    return layer

def nota_logit_activation(x, offset=20, threshold=1.0, sharpness=10):
    """Boost the NOTA (last) logit once it rises above ``threshold``.

    Adds ``offset * relu(tanh(sharpness * (x_nota - threshold)))`` to the last column
    of ``x``. For every other column the mask is zero, so tanh(0) = 0 and nothing is
    added.

    Args:
        x: logits with the class axis last.
        offset: maximum amount added to the NOTA logit (the tanh saturates at 1).
        threshold: NOTA logit value above which the boost starts.
        sharpness: slope multiplier inside the tanh.

    Returns:
        Tensor of the same shape and dtype as ``x``.
    """
    x = tf.convert_to_tensor(x)
    n_classes = tf.shape(x)[-1]
    # Build the NOTA mask with TF ops in x's own dtype. This broadcasts over any (possibly
    # unknown) batch size and avoids a float64 mask being multiplied with float32 logits,
    # which TensorFlow rejects.
    nota_mask = tf.one_hot(n_classes - 1, n_classes, dtype=x.dtype)
    return x + offset * tf.nn.relu(tf.math.tanh((x * nota_mask - nota_mask * threshold) * sharpness))

#print("Made it to Main section.")
# Main Section
def main():
    dataset = sys.argv[1]#'cifar10'# cifar10, cifar100, TI
    wbx_test_model = sys.argv[2] # Path to white box model from this directory
    DET_ADV_AVAIL = False
    ADV_SET = "./0_data/DAX_LAPMCD__cifar10_TSNAP2_AMMod_PGD_CE_DLR_Alt_double_term_6_1_cycle__ASR=0.0_VAL_acc=0.914_26APR_notaAttk_VAL_MI_100.pkl"
    max_iter = 100 #100
    batch_size = 128 # See line 294 for TI batch_size
    ITER2 = [0,20]
    # Set Attack List
    attack = sys.argv[3] 
    if attack == 'all':
        attk_list = ['AutoPGD','AutoAttack','SquareAttack','CWL2','CWLinf','DeepFool']
    elif (attack == 'AutoPGD') or (attack == 'AutoAttack') or (attack == 'SquareAttack') or (attack == 'CWL2') or (attack == 'CWLinf') or (attack == 'DeepFool') or (attack == 'JSMA'): 
        attk_list = [attack]
    elif attack == 'subset':
        attk_list = ['CWLinf','CWL2']
    elif attack == 'subset2':
        attk_list = ['SquareAttack','DeepFool','CWL2','CWLinf']
    elif attack == 'subset3':
        attk_list = ['DeepFool','CWLinf']
    elif attack == 'subset4':
        attk_list = ['SquareAttack','DeepFool','AutoPGD','CWLinf']
    elif attack == 'Auto_group':
        attk_list = ['AutoPGD','AutoAttack']
    elif attack == 'targ_AA':
        attk_list = ['targ_AA']
    else:
        print("Failed: Please specify either 'all' or any one of these attacks: 'AutoPGD', 'AutoAttack' ,'SquareAttack' ,'CWL2' ,'CWLinf' ,'DeepFool', 'JSMA'")
        exit()
    
    #### Remaining Settings ####
    NOTA_def = sys.argv[4] == 'True'
    NOTA_attk = sys.argv[5] == 'True'
    MC_dropout = sys.argv[6] == 'True'
    new_LaPlace = False
    LaPlace = sys.argv[7] #== 'True'
    if (LaPlace == 'new') or (LaPlace == 'True'):
        LaPlace = True
        new_LaPlace = True
        laPath = sys.argv[8]
    elif LaPlace == 'old':
        laPath = sys.argv[8]
        LaPlace =  True
    else: 
        LaPlace = False
        laPath = sys.argv[8]
    
    batch_size_runs = sys.argv[9]
    if (batch_size_runs == 'True'):
        batch_size_runs = True
    else:
        batch_size_runs = False
    #### total adversarials at completion
    
    
    sftmx = True
    verbose=False
    print("***************************************************************")
    #print("First In Parallel run on the same GPU for Efficiency Test.")
    print("dataset= ",dataset, " model_path= ", wbx_test_model, " attack= ", attack, " NOTA_def: ", NOTA_def, " NOTA_attk: ", 
         NOTA_attk, " MC_dropout: ", MC_dropout, " LaPlace: ", LaPlace, " LA_weights: ", laPath)
    
    if MC_dropout and LaPlace:
        n_mc = 50
    else:
        n_mc = 30
    
    #### Initial Model Loading ####
    # Load Deterministic Models
    model = keras.models.load_model(wbx_test_model)
    print(wbx_test_model)
    print("Unchanged Model Summary (Before Dropout replacement, if used):")
    model.summary(expand_nested=True, show_trainable=True)
    orig_det_model = keras.models.load_model(wbx_test_model)
    
    ################## Load Dataset & Data Preparation #######################
        
    # Load Dataset
    if dataset == 'cifar10':
        # Load CIFAR-10
        (X_train, y_train), (X_test, y_test), min_pixel_value, max_pixel_value = load_cifar10()
        y_train = np.argmax(y_train, axis=1)
        y_test = np.argmax(y_test, axis=1)
        X_train = X_train.reshape(X_train.shape[0],3072)
        if NOTA_def:
            num_class = 11
        else:
            num_class = 10
        input_shape = [32,32,3]
        total_Advs = 500

    
    elif dataset == 'cifar100':
        # Load CIFAR-100
        (X_train, y_train), (X_test, y_test) = cifar100.load_data(label_mode='fine')
        X_train = X_train / 255.0
        X_test = X_test / 255.0
        print("y_train examples", y_train.shape)
        #y_train = np.argmax(y_train, axis=1)
        y_train = np.squeeze(y_train)
        #y_test = np.argmax(y_test, axis=1)
        y_test = np.squeeze(y_test)
        print("y_train examples", y_train.shape)
        X_train = X_train.reshape(X_train.shape[0],3072)
        if NOTA_def:
            num_class = 101
        else:
            num_class =100
        input_shape = [32,32,3]
        total_Advs = 500

    elif dataset == 'TI':
        if NOTA_def:
            num_class = 201
        else:
            num_class = 200
        input_shape = [64,64,3]
        batch_size = 32
        total_Advs = 1000
    
    else:
        print("Please choose a dataset: 'cifar10', 'cifar100', 'TI'")
        exit()
    
    if (dataset != 'TI'):         
        X_train = X_train.reshape(X_train.shape[0],32,32,3)
        X_test = X_test.reshape(X_test.shape[0],32,32,3)
        y_train = y_train.astype(np.float64)
        y_test = y_test.astype(np.float64)
        X_val = X_train[:2000]
        y_val = y_train[:2000]
        X_train = X_train[2000:]
        y_train = y_train[2000:]
    
        datagen = ImageDataGenerator(
            rotation_range=15, 
            width_shift_range=0.1,
            height_shift_range=0.1,
            horizontal_flip=True,
            #brightness_range=[0.5,1.5],
            #zoom_range=[0.5,1.0],
        )
        dataset_clean = tf.data.Dataset.from_tensor_slices((X_train, y_train)).shuffle(48000).batch(32)
        global iterator 
        iterator = iter(dataset_clean)
        n_steps = X_train.shape[0] // 32 #batch_size
        training_dataset = datagen.flow(X_train, y_train, 32)
        num_test_samples = 100
        
    else: # dataset == 'TI'
        n_steps = 96000//32
        testset_size = 10000
        trainset_size = 90000
        valset_size = 4000
        class_mapping = get_class_mapping(data_dir)
        class_names = get_class_names(data_dir)
        val_split = get_val_split(data_dir)
        train_filenames, eval_filenames, val_filenames = get_train_val_filenames(data_dir)
        train_labels, eval_labels, val_labels = get_train_eval_val_labels(data_dir, val_split, train_filenames, 
                                                        eval_filenames, val_filenames)
        #training_dataset = get_train_dataset(data_dir, train_filenames, train_labels, 
        #                                     class_mapping, int(1000))
        validation_dataset = get_val_dataset(data_dir, eval_filenames, eval_labels, 
                                             class_mapping, valset_size)
        test_dataset = get_val_dataset(data_dir, val_filenames, val_labels, 
                                             class_mapping, testset_size)
        # New Clean training dataset and iterable are created to allow padding to be created 
        # from non-augmented examples. Test iterables necessary since entire test set does not fit in 
        # memory
        clean_training_dataset_long = get_clean_train_dataset(data_dir, train_filenames,
                                                              train_labels, class_mapping, trainset_size)
        
        test_iter = test_dataset.__iter__()
        val_iter = validation_dataset.__iter__() 
        X_test, y_test = next(test_iter)
        clean_train_long_iter = clean_training_dataset_long.__iter__()
        X_val, y_val = next(val_iter)
        X_train, y_train = next(clean_train_long_iter)
        # To provide early stopping function 1000 
        print("Data is loaded.")
        num_test_samples = 200
        X_test = X_test.numpy()
        y_test.numpy()
    
    # Set to create adversarial examples to test robustness from
    #X_advSrc = X_test[100:(100+num_test_samples)]    
    #y_advSrc = y_test[100:(100+num_test_samples)]
    X_advSrc_val = X_val[100:(100+num_test_samples)]    
    y_advSrc_val = y_val[100:(100+num_test_samples)]
    
    loss = keras.losses.CategoricalCrossentropy()
    
    l = input_shape[0]
    w = input_shape[1]
    d = input_shape[2]
    full_dim = l*w*d
    
    ###################################################################################
    
    ################ Prepare Models and Mature Stochastic Models if Applicable ########
    # MCDropout Model Conversion
    tag = "_"
    if LaPlace:
        tag = tag+"LAP"
    if MC_dropout:
        # Insert conversion process here. Strip normal dropout layers and replace with MCDropout Layers. Difference is that
        # MCDropout layers behave as if Training = True no matter what training is set to.
        MC_model = insert_layer_nonseq(model, 'dropout', mcdropout_layer_factory, insert_layer_name="mcdrop", 
                                       position='replace')
        # Fix possible problems with new model
        MC_ver = './models/' + wbx_test_model[17:] + '_MC_wb_ver.h5'
        MC_model.save(MC_ver)
        model = load_model(MC_ver, custom_objects={'MCDropout': MCDropout})
        
        print("Dropout Layers have been replaced with MC Dropout layers, causing behavior to always enable dropout, ",
              "regardless of Training boolean. This allows Training boolean to be kept false so that batch normalized ",
              "layers will not be affected. In order to use a model without dropout enabled, use the original model.")
        model.summary(expand_nested=True, show_trainable=True)
        tag = tag+"MCD_"
           
    # If attacks are CW, AutoAttack, AutoPGD, DF, or JSMA strip off the softmax layer, else leave it on.
    orig_det_model.layers[-1].activation=None
    model.layers[-1].activation=None
        
    # Original Deterministic Model to be used Vanilla deterministic and basis for building normal Laplace Approximation Model.
    orig_det_model.compile(loss=keras.losses.SparseCategoricalCrossentropy(), optimizer="adam", metrics=["accuracy"])
    model.compile(loss=keras.losses.SparseCategoricalCrossentropy(), optimizer="adam", metrics=["accuracy"])

    ################## Prepare for LaPlace if needed, create low confidence AEs ############################
    # Create Deterministic Adversarial model and examples for training Laplace Fit or load from directory 
    classifier_det = TensorFlowV2Classifier(model=orig_det_model, clip_values=(0,1), nb_classes=num_class, 
                                            input_shape=input_shape, loss_object=loss)
    if LaPlace or MC_dropout:
        if (DET_ADV_AVAIL and new_LaPlace):
            # Original approach used base model deterministic adversarial examples to find the right lamda
            # Now also loading low confidence adversarials from both stochastic and deterministic to find lamda
            (X_det_adv, y_det_adv) = pickle.load(open(ADV_SET,'rb'))
                    #"./0_data/Det_advEx2_BP_C10_dropout_VALIDATION_oldAttk_attack_success=0.15_VAL_acc=0.922.pkl",'rb'))
            #all_aes = pickle.load(open(STOCH_SET,'rb'))
            X_adv_lam = X_det_adv.copy()
            y_adv_lam = y_det_adv.copy()
        elif(new_LaPlace):
            X_adv_lam = []
            y_adv_lam = []
            #for i in ITER2:
                #det_attk = CarliniL2Method(classifier=classifier_det, confidence=i, initial_const=1, max_iter=max_iter,
                #                           batch_size=batch_size, targeted=False)
                #trsfr_x_adv = det_attk.generate(x=X_advSrc_val, y=y_advSrc_val, NOTA=NOTA_attk)
                #X_adv_lam.extend(trsfr_x_adv)
                #y_adv_lam.extend(y_advSrc_val)
            file = wbx_test_model[17:] + "_26APR"
            if NOTA_attk:
                file = file + "_notaAttk_VAL"
            else:
                file = file + "_oldAttk_VAL"
            file = file + "_MI_"+ str(max_iter)
            det_AE = "./0_data/DAX" + tag + file + ".pkl"
            fp2 = open(det_AE,'ab')  
            pickle.dump((X_adv_lam,y_adv_lam),fp2,protocol=pickle.HIGHEST_PROTOCOL)
            fp2.close()
            X_det_adv = X_adv_lam.copy()
            y_det_adv = y_adv_lam.copy()
            print("Deterministic Adversarial examples saved here: ", det_AE)
    #orig_det_model.layers[-1].activation=tf.keras.activations.softmax
    
    # LaPlace Code Here
    if LaPlace:
        mod_type = "LaPlace"
        #print("Weight Decay is: ", weight_decay)
        la = LaplaceApproximation(model)
        if new_LaPlace:
            X_adv_lam = np.asarray(X_adv_lam)
            la.fit((X_train, y_train), (X_val,y_val), (X_adv_lam,y_adv_lam), samples=X_train.shape[0], mc_drop=n_mc, 
                   trgt_perc_trAcc=0.96, batch_size=2000, nota=NOTA_def,sftmx=sftmx,DrpOut=MC_dropout)
            la.save('./models2/'+ wbx_test_model[17:] + tag)
            del X_train
            del y_train
            print('successfully saved your LaPlace Model')
        else:
            del X_train
            del y_train
            la.load(laPath)
            print("LaPlace Model loaded, LaPlace Model Weight Decay set to: ", la.weight_decay, flush=True)
        classifier = TensorFlowV2Classifier(model=la, clip_values=(0,1), nb_classes=num_class, 
                                            input_shape=input_shape, loss_object=loss)
        mod_type = mod_type + "_lamda_"+str(la.weight_decay)
        if MC_dropout:
            print("Trained as ", mod_type, " with MC Dropout")
            mod_type = mod_type + "_MC_Dropout"
        else:
            print("Trained as ", mod_type, " and No MC Dropout")
            mod_type = mod_type + "_NO_MCD"
    
    else:
        mod_type = "No_LaPlace"
        del X_train
        del y_train
        if MC_dropout:
            classifier = TensorFlowV2Classifier(model=model, clip_values=(0,1), nb_classes=num_class, 
                                                input_shape=input_shape, loss_object=loss)
            print("Trained as MC Dropout and " , mod_type)
            mod_type = "MC_Dropout" + mod_type
            
            # Overall Deterministic ASR for the MC_Dropout Model
            ##y_MC_preds = np.zeros((len(y_det_adv),num_class), dtype=float) 
            ##for m in range(n_mc):
            ##    y_preds = classifier.predict(X_det_adv)
            ##    y_MC_preds += y_preds
            #y_pred = y_MC_preds / mc_drop
            ##y_pred = np.argmax(y_MC_preds, axis=1)
            # Logic for temporary ASR here
            ##success = 0.0
            ##for j in range(0,len(y_det_adv)):
                #print(y_pred[j], y_advSrc[j])
            ##    if NOTA_def:
            ##        if y_pred[j] != y_det_adv[j] and (y_pred[j] != num_class - 1):
            ##            success += 1.0
            ##    else: 
            ##        if y_pred[j] != y_det_adv[j]:
            ##            success += 1.0
            ##tmp_asr = success/len(y_det_adv)                
            ##print("MC_Dropout ASR against Deterministic WB Transfer = ", tmp_asr)
            
        else:
            classifier_det = TensorFlowV2Classifier(model=orig_det_model, clip_values=(0,1), nb_classes=num_class,
                                                    input_shape=input_shape, loss_object=loss)
            classifier = TensorFlowV2Classifier(model=orig_det_model, clip_values=(0,1), nb_classes=num_class,
                                                    input_shape=input_shape, loss_object=loss)
            print("Trained as ", mod_type, " and No MC Dropout")
            mod_type = mod_type + "_NO_MCD"
        
        #bb_classifier = TensorFlowV2Classifier(model=bb_model, clip_values=(0,1), nb_classes=num_class, 
        #                                       input_shape=input_shape, loss_object=loss)
        #bb_classifier_det = bb_classifier  # Distinction will be made in using training=True or False or in the model used.
    
    mod_type = mod_type + "_mxIt_" + str(max_iter)
    classifiers = [classifier, classifier_det] #, bb_classifier]

    
    # ***********************************
    # Test Base Deterministic Accuracy
    y_pred = allSet_eval(X_test, dataset, orig_det_model)
    det_accuracy = accuracy_score(y_test, y_pred)
    print("Original Deterministic Model Accuracy:",det_accuracy)
    
    # ***********************************
    
    
    name = wbx_test_model[17:]
    name = "03_tests/" + mod_type + name + ".txt"
    
    #bbx = bbx_test_model[12:]
    print("WBX Model: ", wbx_test_model[17:])
    
    stochastic = LaPlace or MC_dropout
    
    if stochastic:
        y_pred_ben, epist_unc_ben, al_unc_ben, y_sec_ben = MC_predictions(classifiers[0], 
                                                                          X_val, 
                                                                          m=n_mc,
                                                                          num_class=num_class,
                                                                          nota=NOTA_def,
                                                                          sftmx=sftmx,
                                                                          batch_size=500
                                                                          )
        y_pred_ben_t, epist_unc_ben_t, al_unc_ben_t, y_sec_ben_t= MC_predictions(classifiers[0], 
                                                                                 X_test, 
                                                                                 m=n_mc,
                                                                                 num_class=num_class,
                                                                                 nota=NOTA_def,
                                                                                 sftmx=sftmx,
                                                                                 batch_size=500
                                                                                )
        print("Stochastic Benign Accuracy(VAL):", accuracy_score(y_val, y_pred_ben))
        #print("Stochastic Benign Experimental Accuracy(VAL):", accuracy_score(y_val, y_sec_ben))
        print("Stochastic Benign Accuracy(Test):", accuracy_score(y_test, y_pred_ben_t))
        #print("Stochastic Benign Experimental Accuracy(Test):", accuracy_score(y_test, y_sec_ben_t))
            
    else:
        print("No Stochastic Approaches tested see above deterministic accuracy.")

    stoc_adv = []
    det_adv = []
    cnt = 0
    convert = 1 # factor needed to adapt num_test_samples to harvest previously created adversarial examples
    
    norms = [2] #2, 'inf'
    confs = [0] #[0,20]
    gammas = [0.2] #[0.1,0.2]
    DF_eps = [0.5]#[0.1, 0.2, 0.4, 0.8, 1.6]
    
    autoAttk_targ = [False]#[False, True]
    losses = ['anti_NOTA'] #anti_NOTA']#'cross_entropy']#,'difference_logits_ratio']
    # Iterate through Attacks
    det_adv_expls = []
    stoch_adv_expls = []
    
    print("Model is NOTA defended? ", NOTA_def, " Attacks are NOTA-aware? ", NOTA_attk)
    for attk in attk_list:
        print("#################")
        print("Attack: ",str(attk))

        # Reduced batch size necessary for memory constraints in stochastic CWL2 attacks 
        eff_batch_size = batch_size//2
        if batch_size_runs:
            num_groups = int(total_Advs/(eff_batch_size))
        else:
            num_groups = int(total_Advs/num_test_samples)

        # Loop for full set of Adversarial Examples
        print("There will be ", num_groups, " groups of Adversarial examples with ", num_test_samples, " in each.")

        for batch in range(0, num_groups):
            if batch_size_runs:
                X_advSrc = X_test[batch*eff_batch_size:(batch*eff_batch_size+eff_batch_size)] 
                y_advSrc = y_test[batch*eff_batch_size:(batch*eff_batch_size+eff_batch_size)]
            else:
                X_advSrc = X_test[(300+batch*num_test_samples):(300+batch*num_test_samples+num_test_samples)] 
                y_advSrc = y_test[(300+batch*num_test_samples):(300+batch*num_test_samples+num_test_samples)]
            print("########### Group ", batch+1, " Adversarial Examples #############")
            if (attk) == 'targ_AA':
                attk = 'AutoAttack'
                autoAttk_targ = [True]
                norms = [2] # Norms for targeted go here!!!
            if (attk == 'AutoPGD') or (attk == 'AutoAttack') or (attk == 'SquareAttack'):
                for norm in norms:
                    if norm == 'inf':
                        eps = 8/255
                    elif norm == 2:
                        eps = 0.5
                    if attk == 'AutoPGD':
                        for loss_type in losses:
                            attack = AutoProjectedGradientDescent(estimator=classifiers[0], norm=norm, eps=eps, targeted=False, 
                                                                  batch_size=batch_size, loss_type=loss_type, NOTA=NOTA_attk, 
                                                                  verbose=verbose)
                            det_attk = AutoProjectedGradientDescent(estimator=classifiers[1], norm=norm, eps=eps, targeted=False,
                                                                    batch_size=batch_size, loss_type = loss_type, NOTA=NOTA_attk, 
                                                                    verbose=verbose)
                            print("Attack: ", attk, " Loss Type: ", loss_type, " Norm: ", str(norm), " Max Dist: ", eps)
                            advExpDet, advExpSto = attack_and_report(classifiers, det_attk, attack, X_advSrc, y_advSrc, 
                                              NOTA_attk=NOTA_attk, NOTA_def=NOTA_def, num_class=num_class, n_mc=n_mc, 
                                              MC_dropout=MC_dropout, LaPlace=LaPlace)
                            det_adv_expls.extend(advExpDet)
                            stoch_adv_expls.extend(advExpSto)

                    elif attk == 'AutoAttack':
                        for targ in autoAttk_targ:
                            for loss_type in losses:
                                attack = AutoAttack(estimator=classifiers[0], norm=norm, eps=eps, batch_size=batch_size, 
                                                    targeted=targ, NOTA=NOTA_attk, CE="loss_type")
                                det_attk = AutoAttack(estimator=classifiers[1], norm=norm, eps=eps, batch_size=batch_size, 
                                                    targeted=targ, NOTA=NOTA_attk, CE="loss_type")
                                print("Attack: ", attk, " Targeted: ", targ," Loss Type: ", loss_type, " Norm: ", str(norm), 
                                      " Max Dist: ", eps)
                                advExpDet, advExpSto = attack_and_report(classifiers, det_attk, attack, X_advSrc, y_advSrc,
                                                        NOTA_attk=NOTA_attk, NOTA_def=NOTA_def, num_class=num_class, n_mc=n_mc, 
                                                        MC_dropout=MC_dropout, LaPlace=LaPlace)
                                det_adv_expls.extend(advExpDet)
                                stoch_adv_expls.extend(advExpSto)


                    elif attk == 'SquareAttack':
                        #if norm == 2:
                        #    pass
                        #else:
                        attack = SquareAttack(estimator=classifiers[0],norm=norm, eps=eps, max_iter=max_iter, 
                                              NOTA=NOTA_attk)
                        det_attk = SquareAttack(estimator=classifiers[1],norm=norm, eps=eps, max_iter=max_iter, 
                                                NOTA=NOTA_attk)
                        print("Attack: ", attk, " Norm: ", str(norm), " Max Dist: ", eps)
                        advExpDet, advExpSto = attack_and_report(classifiers, det_attk, attack, X_advSrc, y_advSrc, 
                                             NOTA_attk=NOTA_attk, NOTA_def=NOTA_def, num_class=num_class, n_mc=n_mc, 
                                             MC_dropout=MC_dropout, LaPlace=LaPlace)
                        det_adv_expls.extend(advExpDet)
                        stoch_adv_expls.extend(advExpSto)

                    #elif attk == 'PGD':    
                    #    attack = ProjectedGradientDescent(estimator=classifier[0], eps=eps, norm=norm, targeted=targeted, 
                    #                                      num_random_init=1)
                    #    det_attk = ProjectedGradientDescent(estimator=classifier[1], eps=eps, norm=norm, targeted=targeted,
                    #                                        num_random_init=1)
                    #    print("Attack: ", attk, " Norm: ", str(norm), " Dist: ", eps)
                    #    attack_and_report(classifiers, det_attk, attack, X_advSrc, y_advSrc, NOTA_attk:bool=False,
                    #                      NOTA_def:bool=False, num_class:int=num_class, n_mc:int=n_mc, MC_dropout:bool=False,
                    #                      LaPlace:bool=False)
            elif (attk == 'CWL2') or (attk == 'CWLinf'):
                for conf in confs:
                    if attk == 'CWL2':
                        eff_batch_size = batch_size//2
                        attack = CarliniL2Method(classifier=classifiers[0], confidence=conf, initial_const=1, max_iter=max_iter,
                                                 targeted=False, batch_size=eff_batch_size)
                        det_attk = CarliniL2Method(classifier=classifiers[1], confidence=conf, initial_const=1, 
                                                   max_iter=max_iter, targeted=False, batch_size=eff_batch_size)
                        print("Attack: ", attk, " Confidence: ", conf, " Max Iter: ", max_iter)
                        advExpDet, advExpSto = attack_and_report(classifiers, det_attk, attack, X_advSrc, y_advSrc, 
                                          NOTA_attk=NOTA_attk, NOTA_def=NOTA_def, num_class=num_class, n_mc=n_mc, 
                                          MC_dropout=MC_dropout, LaPlace=LaPlace)
                        det_adv_expls.extend(advExpDet)
                        stoch_adv_expls.extend(advExpSto)

                    elif attk == 'CWLinf':
                        attack = CarliniLInfMethod(classifier=classifiers[0], confidence=conf, max_iter=math.ceil(max_iter/4), 
                                                   targeted=False, batch_size=batch_size//2)
                        det_attk = CarliniLInfMethod(classifier=classifiers[1], confidence=conf, max_iter=max_iter,
                                                     targeted=False,
                                                    batch_size=batch_size//2)
                        print("Attack: ", attk, " Confidence: ", conf, " Max Iter: ", max_iter)
                        advExpDet, advExpSto = attack_and_report(classifiers, det_attk, attack, X_advSrc, y_advSrc, 
                                          NOTA_attk=NOTA_attk, NOTA_def=NOTA_def, num_class=num_class, n_mc=math.ceil(n_mc/2), 
                                          MC_dropout=MC_dropout, LaPlace=LaPlace)
                        det_adv_expls.extend(advExpDet)
                        stoch_adv_expls.extend(advExpSto)

                    else:
                        print("Something went wrong.")
                        break

            elif attk == 'DeepFool':
                for eps in DF_eps:
                    attack = DeepFool(classifier=classifiers[0], max_iter=max_iter, epsilon=eps, batch_size=batch_size)
                    det_attk = DeepFool(classifier=classifiers[1], max_iter=max_iter, epsilon=eps, batch_size=batch_size)
                    print("Attack: ", attk, " Max Iter: ", str(max_iter), " Epsilon: ", eps)
                    advExpDet, advExpSto = attack_and_report(classifiers, det_attk, attack, X_advSrc, y_advSrc, 
                                                             NOTA_attk=NOTA_attk, 
                                               NOTA_def=NOTA_def, num_class=num_class, n_mc=n_mc, MC_dropout=MC_dropout, 
                                               LaPlace=LaPlace)
                    det_adv_expls.extend(advExpDet)
                    stoch_adv_expls.extend(advExpSto)

            elif attk == 'JSMA':
                theta = 0.031
                for gamma in gammas: 
                    attack = SaliencyMapMethod(classifier=classifiers[0], theta=theta, gamma=gamma, batch_size=batch_size//4) 
                    #0.01
                    det_attk = SaliencyMapMethod(classifier=classifiers[1], theta=theta, gamma=gamma, batch_size=batch_size//4)
                    #0.01
                    print("Attack: ", attk, " Gamma: ", gamma, " Theta: ", theta)
                    advExpDet, advExpSto = attack_and_report(classifiers, det_attk, attack, X_advSrc, y_advSrc, 
                                                             NOTA_attk=NOTA_attk, 
                                               NOTA_def=NOTA_def, num_class=num_class, n_mc=n_mc, MC_dropout=MC_dropout, 
                                               LaPlace=LaPlace, JSMA=True)
                    det_adv_expls.extend(advExpDet)
                    stoch_adv_expls.extend(advExpSto)

            else:
                print("Please confirm that you specified a supported attack.")
                break

            save_adversarials(det_adv_expls=det_adv_expls, stoch_adv_expls=stoch_adv_expls, y_true_labels=y_advSrc,
                              attk_list=attk_list, wbx_test_model=wbx_test_model, mod_type=mod_type, NOTA_attk=NOTA_attk)
            det_adv_expls = []
            stoch_adv_expls = []
            print("########### Group ", batch+1, " Saved! #############", flush=True)
    #if LaPlace:
    #    la.save('./models2/'+ wbx_test_model[17:] + mod_type)
    #    print('successfully saved your LaPlace Model')
        
# Script entry point: launch the attack/evaluation pipeline only when this
# file is executed directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
    
    
