import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import Sequential, Model, load_model
from tensorflow.keras.layers import Dense, Dropout, Flatten, Activation, Input, Embedding, Lambda
from tensorflow.keras.layers import AveragePooling2D, Concatenate, Add, LeakyReLU, ZeroPadding2D
from tensorflow.keras.layers import Conv2D, UpSampling2D, MaxPooling2D, LayerNormalization
from tensorflow.keras.layers import BatchNormalization, ReLU, Reshape, Conv2DTranspose
from tensorflow.keras import regularizers
from tensorflow.keras.optimizers import Adam, SGD, RMSprop
from tensorflow.keras.regularizers import l2
from tensorflow.keras.utils import get_custom_objects
from tensorflow.keras.callbacks import ModelCheckpoint, LearningRateScheduler
from tensorflow.keras.callbacks import ReduceLROnPlateau
from tensorflow.keras.layers import LeakyReLU
import DonaldDuckDataset
import DonaldDuckConv
import DonaldDuckModel
import DonaldDuckFunc
import DonaldDuckAE
from scipy import integrate


class DRR(DonaldDuckAE.AE):
    """Encoder/decoder trained against a frozen target model and a discriminator.

    The encoder maps an image to a small feature vector; the decoder rebuilds
    an image from that vector concatenated with an intermediate output of the
    frozen target model.  A convolutional discriminator that emits a
    per-location validity map is trained alongside.

    NOTE(review): this class relies on attributes and helpers it does not
    define itself (``dataset``, ``x_train``, ``x_test``, ``input_shape``,
    ``batch_size``, ``epochs``, ``saveModelPath``, ``kernel_size``,
    ``layer_norm``, ``add_noise``, ``importAdv``, ``detect_adv``); they are
    presumably provided by ``DonaldDuckAE.AE`` -- confirm against the base.
    """

    def saveModelWeights(self, idx=''):
        """Save encoder, decoder and discriminator weights as ``.h5`` files.

        Files are written to ``self.saveModelPath`` and are named
        ``weight_<part>_<self.name>_<idx>.h5`` where ``<part>`` is one of
        ``encoder``, ``decoder`` or ``dis``.

        Args:
            idx: Suffix distinguishing one checkpoint from another.
        """
        parts = (
            ('encoder', self.encoder),
            ('decoder', self.decoder),
            ('dis', self.D_img),
        )
        for tag, network in parts:
            # One template for the whole name, instead of relying on
            # ``.format`` binding only to the last concatenated literal.
            file_name = 'weight_{tag}_{name}_{idx}.h5'.format(
                tag=tag,
                name=self.name,
                idx=idx
            )
            network.save_weights(self.saveModelPath + '/' + file_name)

    def setModel(
            self,
            tar_model,
            skip_flag=False,
            skip_layer_num=0
    ):
        """Build and compile every sub-model around a frozen target model.

        Creates ``self.target`` (a truncated, non-trainable view of
        ``tar_model.model``), precomputes its outputs on the train/test data,
        builds the encoder, decoder and discriminator, and finally wires them
        into ``self.combined``.

        Args:
            tar_model: Wrapper exposing ``.model`` (a Keras model) and
                ``.name``.  Its model is frozen in place.
            skip_flag: Unused here; kept so existing callers keep working.
            skip_layer_num: Number of extra trailing layers of the target
                model to skip; the output of layer ``-(skip_layer_num + 2)``
                is used as the conditioning signal.
        """
        self.tar_model = tar_model
        self.tar_model.model.trainable = False

        self.name = self.dataset.name + '_' + tar_model.name + \
                    '_' + str(skip_layer_num)
        self.optimizer = Adam(0.0002, 0.5)

        # Truncated target network: input -> output of the chosen late layer.
        self.target = Model(
            inputs=self.tar_model.model.input,
            outputs=self.tar_model.model.layers[-(skip_layer_num + 2)].output,
            name='Target'
        )
        self.target.trainable = False

        # Target outputs are fixed, so they are computed once up front.
        self.train_labels = self.target.predict(self.x_train)
        self.train_soft_labels = self.layer_norm(self.train_labels)
        self.test_labels = self.target.predict(self.x_test)
        self.label_std = np.std(self.train_labels)

        self.label_shape = self.target.output_shape[1:]
        if self.dataset.name == 'cifar10':
            self.feature_shape = (128,)
        else:
            self.feature_shape = (4,)
        self.valid_shape = (1,)
        self.clip_value = 0.03

        self.encoder, self.decoder = self.build_encoder_decoder()

        self.D_img = self.build_D_image()
        self.D_img.compile(
            loss='mse',
            optimizer=self.optimizer,
            # Keras documents ``metrics`` as a list; a bare string is not
            # accepted by every TF 2.x release.
            metrics=['accuracy']
        )
        # Frozen only after its own compile, so that it is not updated through
        # ``self.combined``.  NOTE(review): whether D_img.train_on_batch still
        # updates the weights afterwards depends on the Keras version -- confirm.
        self.D_img.trainable = False

        labels_i = Input(shape=self.label_shape)
        labels_l = Input(shape=self.label_shape)
        imgs = Input(shape=self.input_shape)

        fake_semantics = self.encoder(imgs)

        # Reconstruction conditioned on the image's own target output.
        rc_imgs = self.decoder([fake_semantics, labels_i])
        rc_logits = self.target(rc_imgs)
        valid_rc = self.D_img(rc_imgs)

        # Same features decoded with a different conditioning signal.
        fake_imgs = self.decoder([fake_semantics, labels_l])
        valid_i = self.D_img(fake_imgs)

        # Non-zero only while the two decodings are closer than 0.5 in mean
        # squared error; it is trained towards zero in fitModel.
        # Author's note on the margin: "0.35 0.5" (presumably values tried).
        hinge = tf.keras.backend.mean(tf.keras.losses.MSE(rc_imgs, fake_imgs))
        hinge = tf.keras.activations.relu(0.5 - hinge)

        # Gaussian noise is added to the fake image only on the path that goes
        # through the target model; the discriminator and the hinge term above
        # see the clean fake image.
        # Author's note on stddev: "0.1 0.3" (presumably values tried).
        batch = tf.keras.backend.shape(fake_semantics)[0]
        epsilon = tf.keras.backend.random_normal(
            stddev=0.1,
            shape=(batch,) + self.input_shape
        )
        noisy_fake_imgs = fake_imgs + epsilon

        fake_logits = self.target(noisy_fake_imgs)
        soft_logits_f = LayerNormalization(
            beta_initializer='zeros',
            gamma_initializer='ones',
            trainable=False
        )(fake_logits)
        soft_logits_r = LayerNormalization(
            beta_initializer='zeros',
            gamma_initializer='ones',
            trainable=False
        )(rc_logits)

        self.combined = Model(
            inputs=[imgs, labels_i, labels_l],
            outputs=[
                rc_imgs, hinge,
                soft_logits_r, soft_logits_f,
                valid_rc, valid_i
            ]
        )
        # One MSE loss per output, in the same order as ``outputs`` above.
        self.combined.compile(
            loss=[
                'mse', 'mse',
                'mse', 'mse',
                'mse', 'mse'
            ],
            loss_weights=[
                100, 1,
                1, 3,
                1, 1
            ],
            optimizer=self.optimizer
        )

    def loadWeights(
            self,
            encoder_weight_path=None,
            decoder_weight_path=None,
            disI_weight_path=None,
    ):
        """Load weights for any sub-model whose path is given.

        Args:
            encoder_weight_path: Weight file for ``self.encoder`` or ``None``.
            decoder_weight_path: Weight file for ``self.decoder`` or ``None``.
            disI_weight_path: Weight file for ``self.D_img`` or ``None``.
        """
        if encoder_weight_path is not None:
            self.encoder.load_weights(encoder_weight_path)
        if decoder_weight_path is not None:
            self.decoder.load_weights(decoder_weight_path)
        if disI_weight_path is not None:
            self.D_img.load_weights(disI_weight_path)

    def fitModel(
            self,
            clean_img_path='',
            adv_img_path=''
    ):
        """Run the alternating discriminator / combined-model training loop.

        Runs ``self.epochs * 100 + 1`` batches.  Every 200 batches the
        weights are saved and, when ``clean_img_path`` is non-empty,
        ``self.detect_adv`` is called.

        Args:
            clean_img_path: Forwarded to ``self.importAdv``; an empty string
                disables the periodic ``detect_adv`` call.
            adv_img_path: Forwarded to ``self.importAdv``.
        """
        self.importAdv(
            clean_img_path=clean_img_path,
            adv_img_path=adv_img_path,
        )

        # The discriminator pools by 2 three times, so its output map is 1/8
        # of the input in each spatial dimension.  Height and width are sized
        # separately so that non-square inputs get correctly shaped targets.
        valid_h = self.input_shape[0] // 8
        valid_w = self.input_shape[1] // 8
        invalid = np.zeros((self.batch_size, valid_h, valid_w, 1))
        valid = np.ones((self.batch_size, valid_h, valid_w, 1))

        # Target for the hinge output of the combined model.
        zeros = np.zeros((self.batch_size, 1))

        batches = self.epochs * 100 + 1

        print(np.std(self.x_train), end=' ')
        print(np.std(self.test_labels))

        for batch in range(batches):
            # Batch whose images are reconstructed.
            idx = np.random.randint(0, self.x_train.shape[0], self.batch_size)
            imgs = self.x_train[idx]
            logits_i = self.train_labels[idx]
            soft_logits_i = self.train_soft_labels[idx]
            # Author's note on std: "0.1 0.3" (presumably values tried).
            imgs_n = self.add_noise(
                imgs,
                mean=0,
                std=0.3,
                eps=1
            )

            # Independent batch that only supplies the alternative
            # conditioning signal.
            idy = np.random.randint(0, self.x_train.shape[0], self.batch_size)
            logits_l = self.train_labels[idy]
            soft_logits_l = self.train_soft_labels[idy]

            semantics = self.encoder.predict(imgs_n)
            fake_imgs = self.decoder.predict([semantics, logits_l])

            # Discriminator step: generated images -> 0, real images -> 1.
            di_loss_1 = self.D_img.train_on_batch(fake_imgs, invalid)
            di_loss_2 = self.D_img.train_on_batch(imgs, valid)

            # Combined step; targets follow the combined model's output order.
            c_loss = self.combined.train_on_batch(
                [
                    imgs_n,
                    logits_i, logits_l,
                ],
                [
                    imgs, zeros,
                    soft_logits_i, soft_logits_l,
                    valid, valid
                ]
            )

            progress = (
                '\r{batch} c_loss:{c_loss} '
                'dI_loss_1:{di_loss_1} '
                'dI_loss_2:{di_loss_2} '
            ).format(
                batch=str(batch) + '/' + str(batches),
                di_loss_1=[round(i, 2) for i in di_loss_1],
                di_loss_2=[round(i, 2) for i in di_loss_2],
                c_loss=[round(i, 3) for i in c_loss],
            )
            print(progress, end='')

            if batch % 200 == 0:
                print()
                img_name = self.name + '_' + str(batch)

                # ``batch % 9`` bounds the suffix to 0..8, so at most nine
                # checkpoint files per sub-model are kept and then overwritten.
                self.saveModelWeights(idx=str(batch % 9))

                if clean_img_path != '':
                    self.detect_adv(
                        img_name=img_name + '_detect'
                    )

    def reconstruct(self, data):
        """Rebuild ``data`` from its encoded features and target outputs.

        When the spread of the target outputs is below the floored training
        spread (``self.label_std``), they are stretched around their mean
        before being passed to the decoder.

        Args:
            data: Batch of inputs accepted by both target and encoder.

        Returns:
            The decoder output after ``self.dataset.clip``.
        """
        logits = self.target.predict(data)
        semantics = self.encoder.predict(data)

        logits_std = np.std(logits)
        reference_std = np.floor(self.label_std)
        # ``logits_std > 0`` guards the division: constant logits would give
        # an infinite factor and NaNs in the decoder input.
        if 0 < logits_std < reference_std:
            alpha = np.ceil(reference_std / logits_std + 1)
            # Scale by alpha while keeping the overall mean unchanged.
            logits = logits * alpha - np.mean(logits) * (alpha - 1)

        rlt = self.decoder.predict([semantics, logits])
        rlt = self.dataset.clip(rlt)
        return rlt

    def build_encoder_decoder(self, skip_layer_num=5, skip_flag=False):
        """Create the convolutional encoder and the conditional decoder.

        Args:
            skip_layer_num: Unused here; kept for interface compatibility.
            skip_flag: Unused here; kept for interface compatibility.

        Returns:
            Tuple ``(encoder, decoder)``.  The encoder maps an image of
            ``self.input_shape`` to ``np.prod(self.feature_shape)`` values;
            the decoder takes ``[features, target_output]`` and returns an
            image with sigmoid-activated values.
        """

        def conv_block(tensor, filters, make_activation):
            # 3x3 same-padded convolution -> batch norm -> activation.
            tensor = Conv2D(
                filters=filters, kernel_size=(3, 3), padding='same'
            )(tensor)
            tensor = BatchNormalization()(tensor)
            return make_activation()(tensor)

        def leaky():
            return LeakyReLU(0.2)

        def relu():
            return Activation('relu')

        # ---- Encoder: four stages of two conv blocks, pooling after the
        # first three.
        en_inputs_img = Input(shape=self.input_shape)
        en_x = en_inputs_img
        for filters, pool_after in (
                (64, True), (128, True), (256, True), (512, False)
        ):
            en_x = conv_block(en_x, filters, leaky)
            en_x = conv_block(en_x, filters, leaky)
            if pool_after:
                en_x = MaxPooling2D(pool_size=(2, 2))(en_x)

        en_x = Flatten()(en_x)
        en_x = Dense(256)(en_x)
        en_x = BatchNormalization()(en_x)
        en_x = LeakyReLU(0.2)(en_x)
        en_x = Dense(np.prod(self.feature_shape))(en_x)

        encoder = Model(
            inputs=en_inputs_img,
            outputs=en_x,
            name='Encoder'
        )

        # ---- Decoder: starts at 1/4 resolution and upsamples twice.
        de_input_f = Input(shape=self.feature_shape)
        de_input_c = Input(shape=self.label_shape)

        de_x = tf.keras.layers.Concatenate()([de_input_f, de_input_c])

        quarter_h = self.input_shape[0] // 4
        quarter_w = self.input_shape[1] // 4
        de_x = Dense(quarter_h * quarter_w * 16)(de_x)
        de_x = BatchNormalization()(de_x)
        de_x = Activation('relu')(de_x)
        de_x = Reshape((quarter_h, quarter_w, 16))(de_x)

        for filters, upsample_after in (
                (512, False), (256, True), (128, True), (64, False)
        ):
            de_x = conv_block(de_x, filters, relu)
            de_x = conv_block(de_x, filters, relu)
            if upsample_after:
                de_x = UpSampling2D()(de_x)

        # Project back to the image channel count.
        de_x = Conv2D(
            self.input_shape[-1], self.kernel_size, padding='same'
        )(de_x)
        de_x = BatchNormalization()(de_x)
        de_x = Activation('sigmoid')(de_x)

        decoder = Model(
            inputs=[de_input_f, de_input_c],
            outputs=de_x,
            name='Decoder'
        )

        return encoder, decoder

    def build_D_image(self):
        """Create the (uncompiled) image discriminator.

        Returns:
            A ``Sequential`` model mapping an image of ``self.input_shape`` to
            a single-channel sigmoid map; three 2x2 poolings reduce each
            spatial dimension by a factor of 8.
        """
        model = Sequential()
        is_first_conv = True
        for filters, pool_after in (
                (64, True), (128, True), (256, True), (512, False)
        ):
            for _ in range(2):
                if is_first_conv:
                    # Only the first layer declares the input shape.
                    model.add(Conv2D(filters, (3, 3), padding='same',
                                     input_shape=self.input_shape))
                    is_first_conv = False
                else:
                    model.add(Conv2D(filters, (3, 3), padding='same'))
                model.add(BatchNormalization(momentum=0.8))
                model.add(LeakyReLU(alpha=0.2))

            if pool_after:
                model.add(MaxPooling2D(pool_size=(2, 2)))
                model.add(Dropout(0.25))

        model.add(Conv2D(1, (3, 3), padding='same'))

        model.add(Activation('sigmoid'))
        return model

# Script entry point: running this file directly only emits a blank line.
if __name__ == "__main__":
    print("")