#!/usr/bin/env python3
import tensorflow as tf
import eagerpy as ep
from foolbox import TensorFlowModel, accuracy, samples
import foolbox.attacks as f_attacks
from foolbox.attacks import LinfPGD
import DonaldDuckDataset
import DonaldDuckModel
import DonaldDuckConv
import DonaldDuckFunc
from DonaldDuckAttack import DonaldDuckAttack
import numpy as np
import os
from foolbox.criteria import TargetedMisclassification, Misclassification

# Registry of pre-built foolbox attacks.
#
# Layout: attackMethods[<attack family>][<norm>] -> {
#     'method':  an already-instantiated foolbox attack object,
#     'epsilon': {<dataset key>: [epsilon, ...]},
# }
# Dataset keys are 'mnist', 'cifar10' and 'fashion'. The attack objects are
# created once, when this module is imported. How the epsilon lists are
# consumed is not visible in this file.
attackMethods = {
    'BIM': {
        'L1': {
            'method': f_attacks.L1BasicIterativeAttack(),
            'epsilon': {
                'mnist': [20, 50],
                'cifar10': [5, 10],
                'fashion': [20, 50],
            },
        },
        'L2': {
            'method': f_attacks.L2BasicIterativeAttack(),
            'epsilon': {
                'mnist': [2, 5],
                'cifar10': [0.1, 0.3],
                'fashion': [2, 5],
            },
        },
        'Linf': {
            'method': f_attacks.LinfBasicIterativeAttack(),
            'epsilon': {
                'mnist': [0.1, 0.3],
                'cifar10': [0.005, 0.01],
                'fashion': [0.01, 0.05],
            },
        },
    },
    'PGD': {
        'L1': {
            'method': f_attacks.L1ProjectedGradientDescentAttack(),
            'epsilon': {
                'mnist': [10, 20],
                'cifar10': [5, 10],
                'fashion': [10, 20],
            },
        },
        'L2': {
            'method': f_attacks.L2ProjectedGradientDescentAttack(),
            'epsilon': {
                'mnist': [1, 2],
                'cifar10': [0.1, 0.3],
                'fashion': [1, 2],
            },
        },
        'Linf': {
            'method': f_attacks.LinfProjectedGradientDescentAttack(),
            'epsilon': {
                'mnist': [0.1, 0.3],
                'cifar10': [0.005, 0.01],
                'fashion': [0.01, 0.05],
            },
        },
    },
    'FGSM': {
        'L1': {
            'method': f_attacks.L1FastGradientAttack(),
            'epsilon': {
                'mnist': [20, 50],
                'cifar10': [5, 10],
                'fashion': [20, 50],
            },
        },
        'L2': {
            'method': f_attacks.L2FastGradientAttack(),
            'epsilon': {
                'mnist': [2, 5],
                'cifar10': [0.3, 1],
                'fashion': [5],
            },
        },
        'Linf': {
            'method': f_attacks.LinfFastGradientAttack(),
            'epsilon': {
                'mnist': [0.1],
                'cifar10': [0.01, 0.05],
                'fashion': [0.01, 0.05],
            },
        },
    },
    'DeepFool': {
        'L2': {
            'method': f_attacks.L2DeepFoolAttack(
                steps=100,
                overshoot=1e-2,
                loss="logits",
            ),
            'epsilon': {
                'mnist': [0],
                'cifar10': [0],
                'fashion': [0],
            },
        },
        'Linf': {
            'method': f_attacks.LinfDeepFoolAttack(
                steps=100,
                overshoot=1e-2,
                loss="logits",
            ),
            'epsilon': {
                'mnist': [0],
                'cifar10': [0],
                'fashion': [0],
            },
        },
    },
    'CW': {
        'L2': {
            'method': f_attacks.L2CarliniWagnerAttack(
                binary_search_steps=10,
                steps=10000,
                stepsize=1e-2,
            ),
            'epsilon': {
                'mnist': [100],
                'cifar10': [100],
                'fashion': [100],
            },
        },
    },
}

class FoolboxAttack(DonaldDuckAttack):
    """Generate adversarial examples with foolbox attacks, one batch at a time.

    The methods read the following attributes, which are not set in this
    class (presumably provided by ``DonaldDuckAttack`` -- verify there):

    * ``self.model.model``   -- model handed to foolbox's ``TensorFlowModel``
      and also called through ``predict``.
    * ``self.model.dataset`` -- object exposing ``clip(array)``.
    * ``self.images``, ``self.labels``, ``self.y_test`` -- arrays sliced along
      their first axis; both label arrays are reduced with ``argmax(axis=1)``.
    """

    def create_adversarial(
            self,
            images,
            labels,
            y_test,
            attackMethod,
            attack_name='',
            epsilons=1
    ):
        """Attack a single batch with an untargeted misclassification criterion.

        Prints three accuracies of ``self.fmodel`` on one line: clean images
        against ``labels``, clean images against ``y_test``, and the clipped
        adversarial images against ``labels``.

        Args:
            images: Batch of inputs, convertible with ``tf.constant``.
            labels: 2-D array; ``argmax`` over axis 1 gives the class ids the
                attack tries to move away from.
            y_test: 2-D array; ``argmax`` over axis 1 gives class ids used only
                for the second printed accuracy.
            attackMethod: Callable foolbox attack, invoked as
                ``attackMethod(fmodel, images, criterion, epsilons=...)``.
            attack_name: Unused here; kept for interface compatibility.
            epsilons: Passed through to the attack. It must be a single value:
                with a sequence foolbox returns lists, which the ``.numpy()``
                call below does not support.

        Returns:
            The adversarial batch after ``self.model.dataset.clip``.
        """
        images = ep.astensors(tf.constant(images))[0]
        labels = ep.astensors(tf.constant(np.argmax(labels, axis=1)))[0]
        y_test = tf.constant(np.argmax(y_test, axis=1))

        print(accuracy(self.fmodel, images, labels), end=' ')
        print(accuracy(self.fmodel, images, y_test), end=' ')

        criterion = Misclassification(labels)
        # Only the first returned element is used (foolbox documents it as the
        # raw adversarials, i.e. not clipped to ``epsilons``).
        advs, _, _ = attackMethod(
            self.fmodel, images, criterion, epsilons=epsilons
        )

        adv_examples = self.model.dataset.clip(advs.numpy())

        print(accuracy(self.fmodel, ep.astensors(adv_examples)[0], labels))
        return adv_examples

    def create_adversarial_pattern(
            self,
            attackMethod,
            attack_name='',
            epsilons=1,
            batch_size=100
    ):
        """Attack every image in ``self.images`` and store the result.

        Builds ``self.fmodel`` (a foolbox ``TensorFlowModel`` with bounds
        ``(0, 1)`` and empty preprocessing), runs ``create_adversarial`` on
        consecutive batches, and finally prints the fraction of adversarial
        examples whose predicted class still equals ``argmax(self.labels)``,
        rounded to four decimals.

        Args:
            attackMethod: Callable foolbox attack forwarded to
                ``create_adversarial``.
            attack_name: Label stored in ``self.attack_name`` and printed.
            epsilons: Forwarded to ``create_adversarial``.
            batch_size: Number of images attacked per call. The final batch
                may be smaller, so every image is processed.

        Returns:
            Tuple ``(self.adv_examples, self.images, self.labels)``.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer")

        preprocessing = dict()
        self.fmodel = TensorFlowModel(
            self.model.model,
            bounds=(0, 1),
            preprocessing=preprocessing
        )
        self.attack_name = attack_name
        print(attack_name)

        num_images = self.images.shape[0]
        self.adv_examples = np.zeros_like(self.images)
        # Step by batch_size and clamp the end so that a trailing partial
        # batch is attacked too instead of being left as zeros.
        for start in range(0, num_images, batch_size):
            batch = slice(start, min(start + batch_size, num_images))
            self.adv_examples[batch] = self.create_adversarial(
                self.images[batch],
                self.labels[batch],
                self.y_test[batch],
                attackMethod=attackMethod,
                attack_name=attack_name,
                epsilons=epsilons
            )

        predicted = np.argmax(
            self.model.model.predict(self.adv_examples), axis=1
        )
        expected = np.argmax(self.labels, axis=1)
        print(round(np.sum(predicted == expected) / len(predicted), 4))

        return self.adv_examples, self.images, self.labels


if __name__ == "__main__":
    # Executing this module as a script only writes a single empty line;
    # no attack is run from here.
    print("")