import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import Sequential, Model, load_model
from tensorflow.keras.layers import Dense, Dropout, Flatten, Activation, Input, Embedding, Lambda
from tensorflow.keras.layers import AveragePooling2D, Concatenate, Add, LeakyReLU, ZeroPadding2D
from tensorflow.keras.layers import Conv2D, UpSampling2D, MaxPooling2D
from tensorflow.keras.layers import BatchNormalization, ReLU, Reshape, Conv2DTranspose
from tensorflow.keras import regularizers
from tensorflow.keras.optimizers import Adam, SGD, RMSprop
from tensorflow.keras.regularizers import l2
from tensorflow.keras.utils import get_custom_objects
from tensorflow.keras.callbacks import ModelCheckpoint, LearningRateScheduler
from tensorflow.keras.callbacks import ReduceLROnPlateau
from tensorflow.keras.layers import LeakyReLU
import DonaldDuckDataset
import DonaldDuckConv
import DonaldDuckModel
import DonaldDuckFunc
import DonaldDuckAE
from scipy import integrate

 
class En_De_R(DonaldDuckAE.AE):
    """Encoder/decoder that reconstructs images under three training signals.

    The combined model maps an image through ``encoder`` -> ``decoder`` and is
    trained with three MSE losses:

    * pixel reconstruction of the input image (weight 100),
    * agreement between the frozen ``target`` network's output on the
      reconstruction and its cached output on the original image
      (weight 0.01),
    * fooling the convolutional discriminator ``D_img``, whose output is a
      spatial map of real/fake scores (weight 0.1).

    NOTE(review): this class relies on attributes/methods provided by
    ``DonaldDuckAE.AE`` (``dataset``, ``input_shape``, ``x_train``,
    ``x_test``, ``batch_size``, ``epochs``, ``kernel_size``,
    ``saveModelPath``, ``importAdv``, ``detect_adv``); they are not defined
    here — confirm against the base class.
    """

    def saveModelWeights(self, idx=''):
        """Save encoder, decoder and discriminator weights.

        Files are written to ``self.saveModelPath`` as
        ``weight_encoder_<name>_<idx>.h5``, ``weight_decoder_<name>_<idx>.h5``
        and ``weight_dis_<name>_<idx>.h5``.

        Args:
            idx: suffix distinguishing checkpoint slots (default: empty).
        """
        suffix = '_{idx}.h5'.format(idx=idx)
        for prefix, net in (
                ('weight_encoder_', self.encoder),
                ('weight_decoder_', self.decoder),
                ('weight_dis_', self.D_img),
        ):
            net.save_weights(
                self.saveModelPath + '/' + prefix + self.name + suffix
            )

    def setModel(
            self,
            tar_model,
            skip_flag=False,
            skip_layer_num=0
    ):
        """Build the target, encoder, decoder, discriminator and combined model.

        Args:
            tar_model: wrapper exposing ``.model`` (a Keras model) and
                ``.name``. Its model is frozen and truncated to form
                ``self.target``.
            skip_flag: accepted for interface compatibility; not used here.
            skip_layer_num: number of extra trailing layers to drop from the
                target model. ``0`` takes the output of ``layers[-2]``.

        Side effects: sets ``name``, ``target``, ``label_shape``,
        ``feature_shape``, ``encoder``, ``decoder``, ``D_img``, ``combined``,
        ``optimizer`` and ``d_optimizer``. It also runs the target over the
        full train and test sets to cache ``train_labels`` / ``test_labels``.
        """
        self.tar_model = tar_model
        self.tar_model.model.trainable = False

        self.name = (
            self.dataset.name + '_' + tar_model.name + '_' + str(skip_layer_num)
        )
        # One Adam instance per compiled model. The Keras optimizers used
        # since TF 2.11 (and in Keras 3) raise when a single instance updates
        # two disjoint variable sets. Hyperparameters: lr=2e-4, beta_1=0.5.
        self.optimizer = Adam(learning_rate=0.0002, beta_1=0.5)
        self.d_optimizer = Adam(learning_rate=0.0002, beta_1=0.5)

        self.target = Model(
            inputs=self.tar_model.model.input,
            outputs=self.tar_model.model.layers[-(skip_layer_num + 2)].output,
            name='Target'
        )
        self.target.trainable = False

        self.label_shape = self.target.output_shape[1:]
        # Latent size: 128 for cifar10, 16 for every other dataset.
        if self.dataset.name == 'cifar10':
            self.feature_shape = (128,)
        else:
            self.feature_shape = (16,)
        self.valid_shape = (1,)
        self.clip_value = 0.03

        self.encoder, self.decoder = self.build_encoder_decoder()

        self.D_img = self.build_D_image()
        self.D_img.compile(
            loss='mse',
            optimizer=self.d_optimizer,
            metrics=['accuracy']
        )
        # Freeze D inside the combined model built below. D_img was compiled
        # while trainable, so its own train_on_batch still updates it (the
        # usual tf.keras GAN pattern).
        self.D_img.trainable = False

        imgs = Input(shape=self.input_shape)
        rc_imgs = self.decoder(self.encoder(imgs))
        rc_logits = self.target(rc_imgs)
        valid_rc = self.D_img(rc_imgs)

        self.combined = Model(
            inputs=imgs,
            outputs=[rc_imgs, rc_logits, valid_rc]
        )
        self.combined.compile(
            loss=['mse', 'mse', 'mse'],
            loss_weights=[100, 0.01, 0.1],
            optimizer=self.optimizer
        )
        # Cache the frozen target's outputs on clean data. fitModel uses
        # train_labels as the regression target for rc_logits.
        self.train_labels = self.target.predict(self.x_train)
        self.test_labels = self.target.predict(self.x_test)

    def loadWeights(
            self,
            encoder_weight_path=None,
            decoder_weight_path=None,
            disI_weight_path=None,
    ):
        """Load weights into each sub-model whose path is given.

        A ``None`` path leaves that sub-model untouched (and unaccessed).
        """
        if encoder_weight_path is not None:
            self.encoder.load_weights(encoder_weight_path)
        if decoder_weight_path is not None:
            self.decoder.load_weights(decoder_weight_path)
        if disI_weight_path is not None:
            self.D_img.load_weights(disI_weight_path)

    def fitModel(
            self,
            clean_img_path='',
            adv_img_path=''
    ):
        """Run the adversarial training loop.

        The loop runs ``epochs * 100 + 1`` iterations. Each iteration:

        1. samples ``batch_size`` training images (indices drawn with
           replacement);
        2. trains ``D_img`` on their reconstructions (target 0), then on the
           real images (target 1);
        3. trains ``combined`` toward (input image, cached target output,
           "real").

        Every 200 iterations, weights are saved into one of 9 rotating slots.
        ``detect_adv`` is also run when ``clean_img_path`` is non-empty.

        Args:
            clean_img_path: forwarded to ``importAdv``; an empty string
                disables periodic detection.
            adv_img_path: forwarded to ``importAdv``.
        """
        self.importAdv(
            clean_img_path=clean_img_path,
            adv_img_path=adv_img_path,
        )

        # Discriminator targets must match its output map, e.g. (H/8, W/8, 1)
        # after three 2x2 poolings. Reading the shape from the model also
        # handles non-square images.
        target_shape = (self.batch_size,) + tuple(self.D_img.output_shape[1:])
        invalid = np.zeros(target_shape)
        valid = np.ones(target_shape)

        batches = self.epochs * 100 + 1

        for batch in range(batches):
            idx = np.random.randint(0, self.x_train.shape[0], self.batch_size)
            imgs = self.x_train[idx]
            logits_i = self.train_labels[idx]

            # verbose=0 keeps per-call progress bars from clobbering the
            # single-line '\r' status output below.
            semantics = self.encoder.predict(imgs, verbose=0)
            fake_imgs = self.decoder.predict(semantics, verbose=0)

            di_loss_1 = self.D_img.train_on_batch(fake_imgs, invalid)
            di_loss_2 = self.D_img.train_on_batch(imgs, valid)

            c_loss = self.combined.train_on_batch(
                imgs,
                [imgs, logits_i, valid]
            )

            status = (
                '\r{batch} c_loss:{c_loss} '
                'dI_loss_1:{di_loss_1} '
                'dI_loss_2:{di_loss_2} '
            ).format(
                batch=str(batch) + '/' + str(batches),
                di_loss_1=[round(i, 2) for i in di_loss_1],
                di_loss_2=[round(i, 2) for i in di_loss_2],
                c_loss=[round(i, 3) for i in c_loss],
            )
            print(status, end='')

            if batch % 200 == 0:
                print()
                img_name = self.name + '_' + str(batch)

                # gcd(200, 9) == 1, so batch % 9 cycles through all 9 slots.
                self.saveModelWeights(idx=str(batch % 9))

                if clean_img_path != '':
                    self.detect_adv(
                        img_name=img_name + '_detect'
                    )

    def reconstruct(self, data):
        """Encode then decode ``data`` and clip the result with ``dataset.clip``."""
        semantics = self.encoder.predict(data)
        rlt = self.decoder.predict(semantics)
        rlt = self.dataset.clip(rlt)
        return rlt

    def build_encoder_decoder(self):
        """Build and return the ``(encoder, decoder)`` Keras models.

        Encoder: Conv pairs of 64/128/256/512 filters, each pair followed by a
        2x2 max-pool except the last. Then Flatten -> Dense(256) -> BN ->
        LeakyReLU -> linear Dense(prod(feature_shape)).

        Decoder: Dense to int(H/4)*int(W/4)*16 -> BN -> ReLU -> Reshape. Then
        Conv pairs of 512, 256, [up], 128, [up], 64 filters, and a final Conv
        to ``input_shape[-1]`` channels -> BN -> sigmoid. The output spatial
        size equals the input's only when H and W are divisible by 4.

        Keep the layer creation order stable: tf.keras loads .h5 weights by
        layer order, so reordering breaks existing weight files.
        """

        def conv_pair(x, filters, make_activation):
            # Two 3x3 'same' Conv2D -> BatchNormalization -> activation units.
            for _ in range(2):
                x = Conv2D(filters=filters, kernel_size=(3, 3),
                           padding='same')(x)
                x = BatchNormalization()(x)
                x = make_activation()(x)
            return x

        def leaky():
            return LeakyReLU(0.2)

        def relu():
            return Activation('relu')

        # ---- Encoder ----
        en_inputs_img = Input(shape=self.input_shape)
        en_x = en_inputs_img
        for filters in (64, 128, 256):
            en_x = conv_pair(en_x, filters, leaky)
            en_x = MaxPooling2D(pool_size=(2, 2))(en_x)
        en_x = conv_pair(en_x, 512, leaky)

        en_x = Flatten()(en_x)
        en_x = Dense(256)(en_x)
        en_x = BatchNormalization()(en_x)
        en_x = LeakyReLU(0.2)(en_x)
        en_x = Dense(int(np.prod(self.feature_shape)))(en_x)

        encoder = Model(
            inputs=en_inputs_img,
            outputs=en_x,
            name='Encoder'
        )

        # ---- Decoder ----
        rows = int(self.input_shape[0] / 4)
        cols = int(self.input_shape[1] / 4)

        de_input = Input(shape=self.feature_shape)
        de_x = Dense(rows * cols * 16)(de_input)
        de_x = BatchNormalization()(de_x)
        de_x = Activation('relu')(de_x)
        de_x = Reshape((rows, cols, 16))(de_x)

        de_x = conv_pair(de_x, 512, relu)
        de_x = conv_pair(de_x, 256, relu)
        de_x = UpSampling2D()(de_x)
        de_x = conv_pair(de_x, 128, relu)
        de_x = UpSampling2D()(de_x)
        de_x = conv_pair(de_x, 64, relu)

        de_x = Conv2D(self.input_shape[-1], self.kernel_size,
                      padding='same')(de_x)
        de_x = BatchNormalization()(de_x)
        de_x = Activation('sigmoid')(de_x)

        decoder = Model(
            inputs=de_input,
            outputs=de_x,
            name='Decoder'
        )

        return encoder, decoder

    def build_D_image(self):
        """Build the image discriminator as a Keras ``Sequential`` model.

        Conv pairs of 64/128/256 filters (BN momentum 0.8, LeakyReLU 0.2), each
        followed by 2x2 max-pool and Dropout(0.25). Then a 512-filter Conv
        pair and a 1-channel 3x3 Conv with sigmoid. The output is a spatial
        score map, not a single scalar.
        """
        model = Sequential()

        def add_conv(filters, **conv_kwargs):
            # One 3x3 'same' Conv2D -> BN(momentum=0.8) -> LeakyReLU(0.2) unit.
            model.add(Conv2D(filters, (3, 3), padding='same', **conv_kwargs))
            model.add(BatchNormalization(momentum=0.8))
            model.add(LeakyReLU(0.2))

        add_conv(64, input_shape=self.input_shape)
        add_conv(64)
        model.add(MaxPooling2D(pool_size=(2, 2)))
        model.add(Dropout(0.25))

        for filters in (128, 256):
            add_conv(filters)
            add_conv(filters)
            model.add(MaxPooling2D(pool_size=(2, 2)))
            model.add(Dropout(0.25))

        add_conv(512)
        add_conv(512)

        model.add(Conv2D(1, (3, 3), padding='same'))
        model.add(Activation('sigmoid'))
        return model

if __name__ == '__main__':
    # Placeholder entry point: this module is meant to be imported, so
    # running it directly only writes a single blank line.
    print('')