
import DonaldDuckDataset
from foolbox4attack import attackMethods
from DonaldDuckDRR import DRR
from DonaldDuckEn_De_R import En_De_R
from DonaldDuckDe_R import De_R
import tensorflow as tf
import DonaldDuckConv
import os
import numpy as np

# Registry of pre-trained adversarial-example detectors.
# Top-level key: detector name. 'model' is the detector class that the
# __main__ block instantiates; each dataset key ('cifar10', 'fashion',
# 'mnist') describes where that detector's weight files live:
#   dir  - sub-directory of savedModels/ holding the weight files
#   idx  - index appended to the weight-file name
#   name - weight-file name prefix
detectModels = {
    'DRR': {
        'model': DRR,
        'name': 'DRR',
        'cifar10': dict(dir='2020_11_02_09_53_58', idx=4,
                        name='cifar10_VGG16_cifar10_0_'),
        'fashion': dict(dir='2020_11_02_20_40_26', idx=6,
                        name='fashion_CNN_fashion_5_32_0_'),
        'mnist': dict(dir='2020_11_15_19_38_53', idx=5,
                      name='mnist_CNN_mnist_5_32_0_'),
    },
    'HVR-P': {
        'model': En_De_R,
        'name': 'HVR-P',
        'cifar10': dict(dir='2020_11_13_15_23_33', idx=5,
                        name='cifar10_VGG16_cifar10_0_'),
        'fashion': dict(dir='2020_11_14_20_19_02', idx=4,
                        name='fashion_CNN_fashion_5_32_0_'),
        'mnist': dict(dir='2020_11_15_16_58_16', idx=5,
                      name='mnist_CNN_mnist_5_32_0_'),
    },
    'HVR-L': {
        'model': En_De_R,
        'name': 'HVR-L',
        'cifar10': dict(dir='2020_11_13_17_16_46', idx=3,
                        name='cifar10_VGG16_cifar10_0_'),
        'fashion': dict(dir='2020_11_14_20_47_35', idx=1,
                        name='fashion_CNN_fashion_5_32_0_'),
        'mnist': dict(dir='2020_11_15_17_28_04', idx=5,
                      name='mnist_CNN_mnist_5_32_0_'),
    },
    'HLR-P': {
        'model': De_R,
        'name': 'HLR-P',
        'cifar10': dict(dir='2020_11_13_18_13_58', idx=1,
                        name='cifar10_VGG16_cifar10_0_'),
        'fashion': dict(dir='2020_11_14_21_31_24', idx=1,
                        name='fashion_CNN_fashion_5_32_0_'),
        'mnist': dict(dir='2020_11_15_17_51_12', idx=1,
                      name='mnist_CNN_mnist_5_32_0_'),
    },
    'HLR-L': {
        'model': De_R,
        'name': 'HLR-L',
        'cifar10': dict(dir='2020_11_13_19_21_15', idx=8,
                        name='cifar10_VGG16_cifar10_0_'),
        'fashion': dict(dir='2020_11_14_22_08_02', idx=1,
                        name='fashion_CNN_fashion_5_32_0_'),
        'mnist': dict(dir='2020_11_15_17_52_21', idx=3,
                      name='mnist_CNN_mnist_5_32_0_'),
    },
}
def train(dgan, dataset=None):
    """Fit ``dgan`` on the BIM/L1 adversarial set and save its weights.

    Only the first BIM/L1 epsilon configured for the dataset is used. The
    clean/adversarial CSVs are read from
    ``data//<tar_model.name>-{clean,adv}-BIM_L1_<epsilon>.csv``.

    Args:
        dgan: detector object exposing ``tar_model.name``, ``fitModel`` and
            ``saveModelWeights``.
        dataset: object with a ``name`` attribute used to look up epsilons in
            ``attackMethods``. Defaults to the module-level ``dataset``
            created when this file runs as a script.

    Raises:
        ValueError: if no dataset is passed and no module-level one exists.
    """
    if dataset is None:
        # The module-level ``dataset`` only exists when run as a script, so
        # look it up explicitly instead of failing with a bare NameError.
        dataset = globals().get('dataset')
        if dataset is None:
            raise ValueError(
                'train() needs a dataset: pass dataset=... or run this '
                'file as a script'
            )
    epsilon = attackMethods['BIM']['L1']['epsilon'][dataset.name][0]
    tag = 'BIM_L1_' + str(epsilon)
    prefix = 'data//' + dgan.tar_model.name
    dgan.fitModel(
        clean_img_path=prefix + '-clean-' + tag + '.csv',
        adv_img_path=prefix + '-adv-' + tag + '.csv',
    )
    dgan.saveModelWeights()

def test(dgan, dataset=None):
    """Run ``dgan`` detection over every configured attack/norm/epsilon.

    For each attack in ``attackMethods``, each norm under it, and each
    epsilon listed for the dataset, this:

    1. loads ``data//<tar_model.name>-{clean,adv}-<attack>_<norm>_<eps>.csv``
       via ``dgan.importAdv``;
    2. prints the ``<attack>_<norm>_<eps>`` tag (space-terminated);
    3. calls ``dgan.detect_adv`` with that tag as ``img_name``.

    Args:
        dgan: detector object exposing ``tar_model.name``, ``importAdv`` and
            ``detect_adv``.
        dataset: object with a ``name`` attribute used to look up epsilons in
            ``attackMethods``. Defaults to the module-level ``dataset``
            created when this file runs as a script.

    Raises:
        ValueError: if no dataset is passed and no module-level one exists.
    """
    if dataset is None:
        # The module-level ``dataset`` only exists when run as a script, so
        # look it up explicitly instead of failing with a bare NameError.
        dataset = globals().get('dataset')
        if dataset is None:
            raise ValueError(
                'test() needs a dataset: pass dataset=... or run this '
                'file as a script'
            )
    prefix = 'data//' + dgan.tar_model.name
    for attack_name, norms in attackMethods.items():
        for norm_name, config in norms.items():
            for epsilon in config['epsilon'][dataset.name]:
                tag = attack_name + '_' + norm_name + '_' + str(epsilon)
                dgan.importAdv(
                    clean_img_path=prefix + '-clean-' + tag + '.csv',
                    adv_img_path=prefix + '-adv-' + tag + '.csv',
                )
                print(tag, end=' ')
                dgan.detect_adv(img_name=tag)

if __name__ == '__main__':
    # Expose only GPU 0 (devices numbered in PCI bus order) to CUDA.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    physical_devices = tf.config.list_physical_devices('GPU')
    if physical_devices:
        try:
            # Grow GPU memory on demand instead of reserving it all up front.
            tf.config.experimental.set_memory_growth(physical_devices[0], True)
        except (RuntimeError, ValueError) as err:
            # Best-effort: e.g. memory growth cannot be changed once the
            # runtime is initialised; continue with the default policy.
            print('Could not enable GPU memory growth:', err)
    tf.random.set_seed(123)

    # Select the dataset (and thus the victim model) by (un)commenting a line.
    dataset = DonaldDuckDataset.CIFAR10(standardization=False)
    # dataset = DonaldDuckDataset.Fashion(standardization=False)
    # dataset = DonaldDuckDataset.MNIST(standardization=False)

    if dataset.name == 'cifar10':
        tar_model = DonaldDuckConv.DonaldDuckVGG16(dataset, build_dir=False)
        tar_model.setModel()
        tar_model.load_model(
            weights_path=f'savedModels//VGG16_{dataset.name}.h5'
        )
    else:
        conv_layers_num = 5
        init_filters = 32
        tar_model = DonaldDuckConv.DonaldDuckCNN(dataset, build_dir=False)
        tar_model.setModel(
            conv_layers_num=conv_layers_num,
            filters=init_filters,
            kernel_size=(3, 3),
        )
        tar_model.load_model(
            weights_path=(
                f'savedModels//CNN_{dataset.name}'
                f'_{conv_layers_num}_{init_filters}.h5'
            )
        )

    # Choose the detector: one of the detectModels keys
    # ('DRR', 'HVR-P', 'HVR-L', 'HLR-P', 'HLR-L').
    dms = 'DRR'
    dgan = detectModels[dms]['model'](
        dataset,
        batch_size=64,
        epochs=200,
        kernel_size=(3, 3),
        # build_dir=False,
    )
    dgan.setModel(tar_model=tar_model, skip_flag=False)

    # Build the pre-trained weight paths for this detector/dataset pair:
    # savedModels//<dir>//weight_<part>_<name><idx>.h5
    checkpoint = detectModels[dms][dataset.name]
    weight_dir = 'savedModels//' + checkpoint['dir']
    weight_suffix = checkpoint['name'] + str(checkpoint['idx']) + '.h5'
    dgan.loadWeights(
        encoder_weight_path=weight_dir + '//weight_encoder_' + weight_suffix,
        decoder_weight_path=weight_dir + '//weight_decoder_' + weight_suffix,
        disI_weight_path=weight_dir + '//weight_dis_' + weight_suffix,
    )

    # train() runs after the checkpoint above has been loaded; comment it out
    # to evaluate the pre-trained detector only.
    train(dgan=dgan)

    test(dgan=dgan)