#!/usr/bin/env python
# -*- coding: utf-8 -*-

import tensorflow as tf
import matplotlib as mpl
import matplotlib.pyplot as plt
import DonaldDuckModel
import DonaldDuckDataset
import DonaldDuckConv
import numpy as np
import pandas as pd
import DonaldDuckFunc

def persent(a, b):
    """Return the fraction of positions at which ``a`` and ``b`` hold equal elements.

    The name presumably stands for "percent", but the result is a fraction in
    [0, 1], not a percentage. Only the first ``len(a)`` elements of ``b`` are
    compared, so ``b`` may be longer than ``a``.

    Args:
        a: reference sequence.
        b: sequence compared element-wise against ``a``.

    Returns:
        Number of equal positions divided by ``len(a)``; 0.0 when ``a`` is empty.

    Raises:
        IndexError: if ``b`` has fewer elements than ``a``.
    """
    total = len(a)
    if total == 0:
        # An empty comparison has no matches; avoid ZeroDivisionError.
        return 0.0
    if len(b) < total:
        # zip() would silently truncate; keep the original IndexError instead.
        raise IndexError('b has fewer elements than a')
    matches = sum(1 for x, y in zip(a, b) if x == y)
    return matches / total

# Epsilon values per dataset. Each dataset maps the keys 1, 2 and np.inf
# (presumably the Lp-norm order of the attack -- confirm at the use site)
# to the list of epsilons to try.
adp_epsilons = dict(
    cifar10={
        1: [5, 10, 15],
        2: [0.1, 0.3, 0.5],
        np.inf: [0.005, 0.01, 0.02],
    },
    fashion={
        1: [10, 20],
        2: [1, 2],
        np.inf: [0.01, 0.05],
    },
    mnist={
        1: [5, 10, 15, 20],
        2: [1, 2, 3],
        np.inf: [0.1, 0.3, 0.5],
    },
)

# Module-level CIFAR-10 dataset, created as soon as this module is imported;
# DonaldDuckAttack uses it to build its default model. standardization=False
# presumably keeps raw (non-standardized) inputs -- confirm in DonaldDuckDataset.
dataset = DonaldDuckDataset.CIFAR10(standardization=False)

class DonaldDuckAttack:
    """Base container for running and evaluating adversarial attacks.

    On construction it selects ``advNum`` samples from the model's test set and
    records the model's predictions on them, plus a random one-hot target class
    per sample. ``perturbation``, ``adv_examples`` and ``attack_name`` start as
    ``None``. They are presumably filled in by a subclass or caller that runs a
    concrete attack before ``test_adv`` / ``saveExamples`` are used.

    The ``model`` wrapper must expose ``x_test``, ``y_test``, a Keras-style
    ``model.predict`` and, for ``test_adv``, ``dataset.clip``.
    """

    def __init__(
            self,
            model=None,
            advNum=100
    ):
        """Select the evaluation samples and draw random target labels.

        Args:
            model: model wrapper (see class docstring). When ``None`` (the
                default), a fresh ``DonaldDuckConv.DonaldDuckCNN(dataset,
                build_dir=False)`` is built for this instance.
            advNum: number of test samples to use. If it is at least the size
                of the test set, the whole test set is used. ``self.advNum`` is
                then set to the number of samples actually selected.
        """
        if model is None:
            # Built per call: as a default argument it was evaluated once at
            # import time and silently shared by every instance.
            model = DonaldDuckConv.DonaldDuckCNN(dataset, build_dir=False)
        self.model = model

        n_test = self.model.x_test.shape[0]
        if advNum >= n_test:
            self.images = self.model.x_test
            self.y_test = self.model.y_test
        else:
            # Without replacement, so no test sample is selected twice.
            idx = np.random.choice(n_test, size=advNum, replace=False)
            self.images = self.model.x_test[idx]
            self.y_test = self.model.y_test[idx]
        # Keep advNum equal to the real sample count; targetLabels and the
        # exported CSV rows are sized from it.
        self.advNum = len(self.images)

        # Model outputs on the clean samples, one row of class scores per image.
        self.labels = self.model.model.predict(self.images)

        # Random target class per sample, one-hot over the model's full output
        # width so the shape does not depend on which classes happened to be drawn.
        num_classes = self.labels.shape[1]
        target_idx = np.random.randint(low=0, high=num_classes, size=self.advNum)
        self.targetLabels = np.eye(num_classes, dtype=np.float32)[target_idx]

        self.perturbation = None
        self.adv_examples = None
        self.attack_name = None

    def test_adv(self, dgan=None):
        """Report accuracy and perturbation size for the current adversarial examples.

        Saves the clipped adversarial images to ``2.png``. It then prints on one
        line:

        - the model's accuracy on the full test set (``bn_acc``);
        - the fraction of adversarial examples still classified as the model's
          clean prediction (``adv_acc``);
        - the mean ``lp=2`` distance between adversarial and clean samples
          (``epsilon``).

        Args:
            dgan: optional detector. When given, its ``testAdv``/``testClean``
                attributes are set and ``detect_adv`` is run.

        Returns:
            ``adv_acc`` when ``dgan`` is None, otherwise
            ``(adv_acc, dis_raw, dis_adv)``.
        """
        DonaldDuckFunc.plot_imgs(
            self.model.dataset.clip(self.adv_examples),
            img_show_flag=False,
            img_save_flag=True,
            img_path='2.png'
        )

        # Clean accuracy against the ground-truth labels (y_test is one-hot).
        clean_pred = np.argmax(self.model.model.predict(self.model.x_test), axis=1)
        truth = np.argmax(self.model.y_test, axis=1)
        bn_acc = round(np.sum(clean_pred == truth) / len(clean_pred), 4)
        print('bn_acc', end=' ')
        print(bn_acc, end=' ')

        # Agreement with the model's own clean predictions, not with y_test.
        adv_pred = np.argmax(self.model.model.predict(self.adv_examples), axis=1)
        labels = np.argmax(self.labels, axis=1)
        adv_acc = round(np.sum(adv_pred == labels) / len(adv_pred), 3)
        print('adv_acc', end=' ')
        print(adv_acc, end=' ')

        epsilon = round(np.mean(DonaldDuckFunc.cal_distance(self.adv_examples, self.images, lp=2)), 4)
        print('epsilon', end=' ')
        print(epsilon)

        if dgan is not None:
            print('detect', end=' ')
            dgan.testAdv = self.adv_examples
            dgan.testClean = self.images
            # NOTE(review): the image name says 'Linf' but epsilon above is
            # computed with lp=2 -- confirm which norm is intended.
            dis_raw, dis_adv, _, _ = dgan.detect_adv(
                img_name='Adapt' + '_' + 'Linf' + '_' + str(epsilon) + '_' + DonaldDuckFunc.getTimeStamp(),
            )
            print()
            return adv_acc, dis_raw, dis_adv
        return adv_acc

    def toCsv(self, tar_data, fileName):
        """Write ``tar_data`` to ``fileName`` as CSV via a pandas DataFrame (index column included)."""
        pd.DataFrame(data=tar_data).to_csv(fileName)

    def saveExamples(self, model_name):
        """Save the clean and adversarial samples as flattened CSV files, one row per sample.

        Writes ``data/<model_name>-clean-<attack_name>.csv`` and
        ``data/<model_name>-adv-<attack_name>.csv``.
        """
        # reshape(n, -1) flattens each sample using its real count, instead of
        # size / advNum, which mis-shapes rows whenever advNum != len(images).
        self.toCsv(self.images.reshape(len(self.images), -1),
                   r'data/' + model_name + '-clean-' + self.attack_name + '.csv')
        self.toCsv(self.adv_examples.reshape(len(self.adv_examples), -1),
                   r'data/' + model_name + '-adv-' + self.attack_name + '.csv')

if __name__ == '__main__':
    # NOTE(review): DonaldDuckAttack's first parameter is the model wrapper,
    # whose ``x_test`` attribute is read during construction. Passing this
    # weights-file path (a str) will raise AttributeError there. Presumably the
    # intent was to load these weights into a model first -- confirm.
    weights_path = r'save_model/step/weight-relu-cifar10-cnn.h5'
    DonaldDuck_a = DonaldDuckAttack(weights_path)
