import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import Sequential, Model, load_model
from tensorflow.keras.layers import Dense, Dropout, Flatten, Activation, Input, Embedding, Lambda
from tensorflow.keras.layers import AveragePooling2D, Concatenate, Add, LeakyReLU, ZeroPadding2D
from tensorflow.keras.layers import Conv2D, UpSampling2D, MaxPooling2D
from tensorflow.keras.layers import BatchNormalization, ReLU, Reshape, Conv2DTranspose
from tensorflow.keras import regularizers
from tensorflow.keras.optimizers import Adam, SGD, RMSprop
from tensorflow.keras.regularizers import l2
from tensorflow.keras.utils import get_custom_objects
from tensorflow.keras.callbacks import ModelCheckpoint, LearningRateScheduler
from tensorflow.keras.callbacks import ReduceLROnPlateau
from tensorflow.keras.layers import LeakyReLU
import DonaldDuckDataset
import DonaldDuckConv
import DonaldDuckModel
import DonaldDuckFunc
from scipy import integrate


class AE(DonaldDuckModel.DonaldDuckModel):

    def batch_normal(self, data):
        data = (data - data.min()) / (data.max() - data.min())
        return data

    def importAdv(
            self,
            clean_img_path,
            adv_img_path
    ):
        img_cols = self.input_shape[0]
        img_rows = self.input_shape[1]
        channels = self.input_shape[2]

        testAdv = np.array(pd.read_csv(adv_img_path, usecols=range(1, img_cols * img_rows * channels + 1)))
        self.testAdv = testAdv.reshape((testAdv.shape[0], img_rows, img_cols, channels))

        testClean = np.array(pd.read_csv(clean_img_path, usecols=range(1, img_cols * img_rows * channels + 1)))
        self.testClean = testClean.reshape((testClean.shape[0], img_rows, img_cols, channels))

        self.perturbation = self.testAdv - self.testClean

    def detect_adv(
            self,
            img_name='',
            plot_flag=True,
            l2_bound=5
    ):
        """Evaluate reconstruction-distance based detection of adversarial inputs.

        Reconstructs ``self.testClean`` and ``self.testAdv`` with
        ``self.reconstruct``, measures the distance (``lp=2``) between each
        input and its reconstruction, and sweeps 20 distance thresholds
        between the smallest and largest observed distance. For each
        threshold two rates are recorded:

        * undetected-attackable rate: share of "attackable" samples whose
          adversarial reconstruction distance is below the threshold;
        * the rate stored in ``false_negative_rate``: share of clean samples
          whose reconstruction distance is above the threshold.

        A sample is "attackable" when the target model labels the clean and
        adversarial versions differently and the perturbation distance is
        below ``l2_bound``.

        One line of space-separated metrics, each rounded to 4 decimals, is
        printed:
        1 minus the trapezoidal area of the rate curve; fraction of
        reconstructed adversarial labels equal to the adversarial labels;
        fraction equal to the clean labels; fraction of clean labels equal
        to adversarial labels; fraction of reconstructed clean labels equal
        to the clean labels; mean distance between adversarial and clean
        inputs.

        Args:
            img_name: forwarded to ``plot_all`` (figure title / file name).
            plot_flag: when true, ``plot_all`` is called with the results.
            l2_bound: perturbation distance below which an adversarial
                example is still considered a valid attack. Defaults to 5.

        Returns:
            Tuple ``(distance_raw, distance_adv_bounded, false_negative_rate,
            undetect_attackable_rate)`` where ``distance_adv_bounded`` holds
            the adversarial reconstruction distances of the samples whose
            perturbation distance is below ``l2_bound``. Both rate lists are
            empty when no sample is attackable.
        """
        re_adv = self.reconstruct(self.testAdv)
        re_clean = self.reconstruct(self.testClean)

        # assumes cal_distance returns one distance per sample (1-D array)
        # -- TODO confirm against DonaldDuckFunc.
        distance_raw = DonaldDuckFunc.cal_distance(self.testClean, re_clean, lp=2)
        distance_adv = DonaldDuckFunc.cal_distance(self.testAdv, re_adv, lp=2)
        perturbations = DonaldDuckFunc.cal_distance(self.testClean, self.testAdv, lp=2)

        re_adv_labels = self.tar_model.predict_label(re_adv)
        re_clean_labels = self.tar_model.predict_label(re_clean)
        cleanLabel = self.tar_model.predict_label(self.testClean)
        advLabel = self.tar_model.predict_label(self.testAdv)

        lower_bound = min(distance_adv.min(), distance_raw.min())
        upper_bound = max(distance_adv.max(), distance_raw.max())

        # This method targets adversarial examples whose semantics have not
        # been changed. When the L2 norm of the perturbation is too large we
        # treat the attack as failed, because a human would notice it.
        attackable = [
            idx
            for idx, (clean, adv, pert) in enumerate(zip(cleanLabel, advLabel, perturbations))
            if clean != adv and pert < l2_bound
        ]

        undetect_attackable_rate = []
        false_negative_rate = []
        if attackable:
            n_attackable = len(attackable)
            n_clean = cleanLabel.shape[0]
            # linspace yields exactly 20 thresholds and, unlike arange with a
            # computed float step, does not fail when all distances are equal.
            for threshold in np.linspace(lower_bound, upper_bound, num=20, endpoint=False):
                undetected = np.where(distance_adv < threshold)[0]
                uar = len(np.intersect1d(undetected, attackable)) / n_attackable
                undetect_attackable_rate.append(uar)

                clean_above = np.where(distance_raw > threshold)[0]
                false_negative_rate.append(len(clean_above) / n_clean)

        # scipy.integrate.trapz was removed from recent SciPy releases in
        # favour of trapezoid; fall back to trapz on old installations.
        trapezoid = getattr(integrate, 'trapezoid', None) or integrate.trapz
        area = trapezoid(undetect_attackable_rate[::-1], false_negative_rate[::-1])
        n_samples = len(re_adv_labels)

        print(round(1 - area, 4), end=' ')
        print(round(np.sum(re_adv_labels == advLabel) / n_samples, 4), end=' ')
        print(round(np.sum(re_adv_labels == cleanLabel) / n_samples, 4), end=' ')
        print(round(np.sum(cleanLabel == advLabel) / n_samples, 4), end=' ')
        print(round(np.sum(re_clean_labels == cleanLabel) / n_samples, 4), end=' ')
        print(round(np.mean(DonaldDuckFunc.cal_distance(self.testAdv, self.testClean, lp=2)), 4))

        if plot_flag:
            self.plot_all(
                disRaw=distance_raw,
                disAdv=distance_adv,
                reClean=re_clean,
                reAdv=re_adv,
                false_negative_rate=false_negative_rate,
                undetect_attackable_rate=undetect_attackable_rate,
                img_name=img_name,
            )
        within_bound = np.where(perturbations < l2_bound)[0]
        return distance_raw, distance_adv[within_bound], false_negative_rate, undetect_attackable_rate

    def plot_all(
            self,
            disRaw, disAdv,
            reClean, reAdv,
            false_negative_rate, undetect_attackable_rate,
            img_name=''
    ):
        """Draw a summary figure of the detection results.

        Layout (4 x 7 grid):
            * top-left: overlaid histograms of ``disAdv`` (red) and
              ``disRaw`` (blue);
            * bottom-left: ``undetect_attackable_rate`` plotted against
              ``false_negative_rate`` on [0, 1] x [0, 1];
            * right, rows 0-1: ten reconstructed adversarial images;
            * right, rows 2-3: the reconstructed clean images at the same
              randomly drawn sample indices.

        When ``img_name`` is non-empty the figure is saved as
        ``<self.saveImgPath>/<img_name>.png``. The current figure is cleared
        before returning.

        Args:
            disRaw: reconstruction distances of the clean samples.
            disAdv: reconstruction distances of the adversarial samples.
            reClean: reconstructed clean images, indexed as
                ``[sample, :, :, channel]``.
            reAdv: reconstructed adversarial images, same layout.
            false_negative_rate: x values of the rate curve.
            undetect_attackable_rate: y values of the rate curve.
            img_name: figure title and output file stem; nothing is saved
                when empty.
        """
        lowG = int(min(disRaw.min(), disAdv.min()))
        highG = int(max(disAdv.max(), disRaw.max()))
        if highG == lowG:
            # All distances share the same integer part: without this the
            # bin width would be 0 and np.arange would raise.
            highG = lowG + 1
        step = (highG - lowG) / 20
        # One extra bin below and two above the truncated integer bounds.
        bins = np.arange(lowG - step, highG + step * 2, step)

        # Only the first 10 of the drawn indices are displayed below.
        sample = np.random.randint(0, self.testAdv.shape[0], 25)
        sample_reAdv = reAdv[sample]
        sample_reClean = reClean[sample]

        gs = GridSpec(4, 7)
        plt.plot()
        plt.title(img_name)

        leftup = plt.subplot(gs[0:2, 0:2])
        for distances, colour in ((disAdv, 'red'), (disRaw, 'blue')):
            leftup.hist(
                distances,
                bins,
                histtype='bar',
                rwidth=0.8,
                color=colour,
                alpha=0.4
            )
        plt.ylim(0, self.testAdv.shape[0])

        leftdown = plt.subplot(gs[2:, 0:2])
        leftdown.plot(false_negative_rate, undetect_attackable_rate)
        plt.ylim(0, 1)
        plt.xlim(0, 1)
        plt.xticks([0, 0.2, 0.4, 0.6, 0.8, 1])

        def _show_tiles(images, grid_rows):
            # Fill columns 2..6 of the given grid rows with consecutive images.
            cells = [(row, col) for row in grid_rows for col in range(2, 7)]
            for tile, (row, col) in enumerate(cells):
                axis = plt.subplot(gs[row, col])
                if images.shape[3] == 1:
                    # Single-channel images are drawn in grayscale.
                    axis.imshow(images[tile, :, :, 0], cmap='gray')
                else:
                    axis.imshow(images[tile, :, :, :])
                axis.axis('off')

        _show_tiles(sample_reAdv, range(0, 2))
        _show_tiles(sample_reClean, range(2, 4))

        if img_name != '':
            plt.savefig(fname=self.saveImgPath + '/' + img_name + '.png')
        # plt.show()
        plt.clf()

    def add_noise(
            self,
            sample,
            eps=1,
            mean=0,
            std=1,
            l_flag=False,
            noise_only=False
    ):
        shape = sample.shape
        noise = np.random.normal(
            mean,
            std,
            shape
        )
        if noise_only:
            res = noise
        else:
            res = noise * eps + sample
        if l_flag:
            res_max = np.argmax(res, axis=-1)
            sample_max = np.argmax(sample, axis=-1)
            for idx in range(len(sample)):
                if res_max[idx] != sample_max[idx]:
                    t = res[idx][sample_max[idx]]
                    res[idx][sample_max[idx]] = res[idx][res_max[idx]]
                    res[idx][res_max[idx]] = t
        return res

    def wasserstein_loss(self, y_true, y_pred):
        """Return the backend mean of the element-wise product ``y_true * y_pred``."""
        backend = tf.keras.backend
        weighted = y_true * y_pred
        return backend.mean(weighted)

    def mutual_info_loss(self, c, c_given_x):
        """Return the sum of two entropy-style terms computed from ``c``.

        Each term is ``mean(-sum(log(p + 1e-8) * c, axis=1))``; the first
        uses ``p = c_given_x`` and the second uses ``p = c``.

        Args:
            c: tensor used as the weights in both terms.
            c_given_x: tensor whose log is weighted by ``c`` in the first term.

        Returns:
            Backend tensor holding the sum of both terms.
        """
        backend = tf.keras.backend
        eps = 1e-8  # keeps log() away from log(0)

        def _mean_neg_weighted_log(probs):
            # mean over the batch of -sum_k log(probs_k + eps) * c_k
            return backend.mean(
                - backend.sum(backend.log(probs + eps) * c, axis=1)
            )

        conditional_entropy = _mean_neg_weighted_log(c_given_x)
        entropy = _mean_neg_weighted_log(c)
        return conditional_entropy + entropy

    def layer_norm(self, data):
        data_mean = np.mean(data, axis=-1)
        data_std = np.std(data, axis=-1)
        rel = ((data.T - data_mean) / data_std).T
        return rel


if __name__ == '__main__':
    # Script entry point: build an AE on the project's MNIST dataset object
    # and call its setModel().
    mnist = DonaldDuckDataset.MNIST()
    dgan = AE(mnist)
    dgan.setModel()
