from keras.regularizers import l2
import tensorflow as tf
from keras.layers import Conv2D, Flatten, Dense, Conv2DTranspose, Lambda, Input, BatchNormalization, ReLU, Reshape
from keras.models import Model
from keras.optimizers import Adam
from models.my_layers.spectral_normalized_dense_conv import DenseSN, ConvSN2D, ConvSN2DTranspose
from models.rae import loss_functions


def get_vae_svhn_vae_architecture(input_shape, embeding_loss_weight, generator_regs, generator_reg_types,
                                   include_batch_norm, spec_norm_dec_only, recon_loss_func, verbose=True,
                                   bottleneck_size=None, num_filter=None):
    """Convenience wrapper that forwards its arguments to ``get_vae_svhn``.

    Previously this function referenced ``bottleneck_size`` and ``num_filter`` without
    defining them, so every call failed with ``NameError``. They are now trailing keyword
    parameters so existing positional call shapes remain valid.

    Args:
        input_shape: Shape of one input image, passed to ``get_vae_svhn``.
        embeding_loss_weight: Weight of the embedding loss (``beta`` in the loss).
        generator_regs: Per-entry regularization strengths, parallel to
            ``generator_reg_types``.
        generator_reg_types: Regularization types ('l2', 'grad_pen', 'spec_norm' or a callable).
        include_batch_norm: Forwarded unchanged to ``get_vae_svhn``.
        spec_norm_dec_only: Forwarded unchanged to ``get_vae_svhn``.
        recon_loss_func: Reconstruction loss forwarded to ``get_vae_svhn``.
        verbose: Whether to print model summaries.
        bottleneck_size: Width of the latent code; required.
        num_filter: Forwarded unchanged to ``get_vae_svhn``.

    Returns:
        Whatever ``get_vae_svhn`` returns (``(encoder, decoder, vae)``).

    Raises:
        ValueError: if ``bottleneck_size`` is not supplied.
    """
    if bottleneck_size is None:
        raise ValueError("bottleneck_size must be given: it sets the width of the latent code 'latent_z'.")
    return get_vae_svhn(input_shape, bottleneck_size=bottleneck_size,
                        embeding_loss_weight=embeding_loss_weight,
                        generator_regs=generator_regs, generator_reg_types=generator_reg_types,
                        include_batch_norm=include_batch_norm, num_filter=num_filter,
                        spec_norm_dec_only=spec_norm_dec_only, recon_loss_func=recon_loss_func,
                        verbose=verbose)


def _parse_generator_regs(generator_regs, generator_reg_types):
    """Translate the parallel regularization lists into model-building options.

    Args:
        generator_regs: Strengths indexed in parallel with ``generator_reg_types``
            (read only for 'l2' and 'grad_pen' entries).
        generator_reg_types: Entries are 'l2', 'grad_pen', 'spec_norm' or a callable
            used directly as the kernel regularizer.

    Returns:
        Tuple ``(regularization, apply_grad_pen, grad_pen_weight, apply_spec_norm)``.
        If several entries set the regularizer, the last one wins.

    Raises:
        NotImplementedError: for an entry that is neither a supported string nor callable.
    """
    regularization = None
    apply_grad_pen = False
    grad_pen_weight = None
    apply_spec_norm = False
    for i, reg_type in enumerate(generator_reg_types):
        if reg_type == 'l2':
            regularization = l2(generator_regs[i])
        elif reg_type == 'grad_pen':
            apply_grad_pen = True
            grad_pen_weight = generator_regs[i]
        elif reg_type == 'spec_norm':
            apply_spec_norm = True
        elif callable(reg_type):
            regularization = reg_type
        else:
            # repr() keeps the message buildable for non-string entries (e.g. None or ints).
            raise NotImplementedError(
                "Specified type of regularization : {!r} has not been implemented".format(reg_type))
    return regularization, apply_grad_pen, grad_pen_weight, apply_spec_norm


def get_vae_svhn(input_shape, bottleneck_size, embeding_loss_weight, generator_regs, generator_reg_types,
                  include_batch_norm, num_filter, spec_norm_dec_only, recon_loss_func, verbose=True):
    """Build and compile a fully-connected autoencoder (encoder, decoder, full model).

    Architecture:
        encoder: Flatten -> Dense(1000) -> ReLU -> Dense(500) -> ReLU -> Dense(bottleneck_size)
        decoder: Dense(500) -> ReLU -> Dense(1000) -> ReLU -> Dense(prod(input_shape), sigmoid)
                 -> Reshape(input_shape)

    Args:
        input_shape: Shape of one input sample (e.g. ``(32, 32, 3)``). The decoder output
            is reshaped to this same shape.
        bottleneck_size: Width of the latent code ``latent_z``.
        embeding_loss_weight: Passed as ``beta`` to ``loss_functions.total_loss``.
        generator_regs: Regularization strengths, parallel to ``generator_reg_types``.
        generator_reg_types: Regularizations to apply:
            'l2' -> ``l2(generator_regs[i])`` as kernel regularizer of the decoder layers;
            'grad_pen' -> enables the gradient penalty with weight ``generator_regs[i]``;
            'spec_norm' -> ``DenseSN`` for ``latent_z`` and the first decoder layer;
            callable -> used directly as the decoder kernel regularizer.
        include_batch_norm: Accepted for interface compatibility; unused by this dense
            architecture.
        num_filter: Accepted for interface compatibility; unused by this dense architecture.
        spec_norm_dec_only: Accepted for interface compatibility; unused by this dense
            architecture (it only affected the former convolutional encoder).
        recon_loss_func: Reconstruction loss forwarded to ``loss_functions.total_loss``.
        verbose: If True, print the encoder, decoder and full-model summaries.

    Returns:
        ``(encoder, decoder, vae)``, where ``vae`` is compiled with Adam.

    Raises:
        NotImplementedError: for an unsupported entry in ``generator_reg_types``.
    """
    regularization, apply_grad_pen, grad_pen_weight, apply_spec_norm = _parse_generator_regs(
        generator_regs, generator_reg_types)
    # Only pass a regularizer when one was requested, so layers are otherwise built exactly as before.
    reg_kwargs = {'kernel_regularizer': regularization} if regularization is not None else {}
    sn_or_dense = DenseSN if apply_spec_norm else Dense

    # Flattened size of one sample; the decoder must emit this many values to match the input.
    out_dim = 1
    for dim in input_shape:
        out_dim *= dim

    with tf.name_scope('encoder'):
        e_in = Input(shape=input_shape, name="input_image")
        x = Flatten()(e_in)
        x = Dense(1000, activation='linear', name='layer_1')(x)
        x = ReLU()(x)
        x = Dense(500, activation='linear', name='layer_2')(x)
        x = ReLU()(x)
        # NOTE(review): the latent layer gets spectral norm even when spec_norm_dec_only is True
        # (behavior kept from the original) -- confirm this is intended.
        z = sn_or_dense(bottleneck_size, activation='linear', name='latent_z')(x)

        encoder = Model(inputs=e_in, outputs=z, name='encoder')
        if verbose:
            print('Encoder')
            encoder.summary()

    with tf.name_scope('decoder'):
        d_in = Input(shape=(bottleneck_size,), name='decoder_noise_in')
        x = sn_or_dense(500, activation='linear', name='layer_1', **reg_kwargs)(d_in)
        x = ReLU()(x)
        x = Dense(1000, activation='linear', name='layer_2', **reg_kwargs)(x)
        x = ReLU()(x)
        x = Dense(out_dim, activation='sigmoid', name='out', **reg_kwargs)(x)
        x = Reshape(tuple(input_shape))(x)

        decoder = Model(inputs=d_in, outputs=x, name='decoder')
        if verbose:
            print('Decoder')
            decoder.summary()

    with tf.name_scope('full_VAE'):
        loss_func = loss_functions.total_loss(z, beta=embeding_loss_weight, apply_grad_pen=apply_grad_pen,
                                              grad_pen_weight=grad_pen_weight, recon_loss_func=recon_loss_func)
        vae_out = decoder(encoder.outputs[0])
        vae = Model(inputs=e_in, outputs=vae_out, name='vae')

        metrics = [loss_functions.per_pix_recon_loss, loss_functions.embeddig_loss(z)]
        if apply_grad_pen:
            metrics.append(loss_functions.grad_pen_loss(z, None))
        metrics.append('mse')
        vae.compile(optimizer=Adam(lr=1e-3, epsilon=1e-08), loss=loss_func, metrics=metrics)

        if verbose:
            vae.summary()

    return encoder, decoder, vae
