# Import the required libraries
# -------------------------------
import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import keras
from keras import layers, metrics, losses, callbacks
from keras.layers import Conv2D, Conv2DTranspose
from keras.optimizers import Adam

from src.geo_conv import GeoConv2D, GeoConv2DTranspose
from evaluation.cvae.celeb_a.dataset_loader import load_dataset
from utils.layers import Sampling
from utils.models import VAE
import pandas as pd
# -------------------------------


# Initialise some of the constants and hyperparameters
# -------------------------------
# Seed TensorFlow and NumPy with the same value
SEED = 0
tf.random.set_seed(SEED)
np.random.seed(SEED)

# Initial learning rate handed to Adam
LR = 0.002

# Activation used by the hidden layers of the encoder/decoder
ACT = tf.nn.relu

# Convolution layer class used throughout the model
CONV_ARCH = Conv2D
# Matching transposed-convolution class, selected by the convolution class name
_TRANSPOSED_CONV_BY_NAME = {
    "Conv2D": Conv2DTranspose,
    "GeoConv2D": GeoConv2DTranspose,
}
if CONV_ARCH.__name__ not in _TRANSPOSED_CONV_BY_NAME:
    raise ValueError("Invalid convolution architecture.")
CONV_TRANS_ARCH = _TRANSPOSED_CONV_BY_NAME[CONV_ARCH.__name__]

# Specify the dataset specifications
CELEB_A = {
    'name': 'celeb_a',
    # NOTE: 'path' is not read by this script (load_dataset is given 'name');
    # it previously held 'sign_lang_mnist', a leftover from another config.
    'path': 'celeb_a',
    'img_shape': (218, 178, 3),
    'label_shape': (40,),
    'batch_size': 128,
    'num_epochs': 30,
}

DATASET = CELEB_A

# NOTE(review): RESULT_DIR is not referenced anywhere else in this script;
# run artifacts are written under RESULTS_DIR ('results/cvae/celeb_a/').
# The directory is still created here.
RESULT_DIR = "results/cvae/"
os.makedirs(RESULT_DIR, exist_ok=True)

# Define the hyperparameters
LATENT_DIMS = [256, 384, 512]
img_shape, label_shape, batch_size = (
    DATASET['img_shape'],
    DATASET['label_shape'],
    DATASET['batch_size'],
)
# Weights of the individual loss terms passed to the VAE
KL_WEIGHT = 50.
BCE_WEIGHT = .1
MSE_WEIGHT = .5
MAE_WEIGHT = .5
SSIM_WEIGHT = 5000.
SHARPNESS_WEIGHT = 20000.
# -------------------------------
# Total number of scalar values in one image (product of all img_shape dims)
img_size = int(np.prod(img_shape))
# -------------------------------


# Load the dataset
# -------------------------------
# One load_dataset call per split (train first, then test), both with the
# dataset name and batch size from DATASET.
x_train, x_test = (
    load_dataset(
        dataset_name=DATASET["name"],
        split=split_name,
        batch_size=DATASET["batch_size"],
    )
    for split_name in ("train", "test")
)
# -------------------------------


# Encoder architecture
# -------------------------------
def encoder_builder(conv_arch, latent_dim):
    """Build the conditional encoder model.

    The image passes through three convolution layers with 64, 128 and 256
    filters, each created as ``conv_arch(filters, 3, 2, activation=ACT)``
    (kernel 3, stride 2 for Conv2D). It is then flattened and concatenated with
    the label vector, and fed through a ``latent_dim``-unit dense layer. Two
    dense heads produce ``z_mean`` and ``z_log_var``, and ``Sampling`` draws
    ``z`` from them.

    Args:
        conv_arch: convolution layer class, e.g. ``Conv2D`` or ``GeoConv2D``.
        latent_dim: size of the latent space.

    Returns:
        keras.Model named "encoder" with inputs ``[img, label]`` and outputs
        ``[z_mean, z_log_var, z]``.
    """
    # Images
    inputs = keras.Input(shape=img_shape, name="img")
    x = inputs
    for filters in (64, 128, 256):
        x = conv_arch(filters, 3, 2, activation=ACT)(x)
    x = layers.Flatten()(x)

    # Labels. Pass the full shape tuple: Keras documents `shape` as a tuple,
    # and Keras 3 rejects a bare int such as label_shape[0].
    labels = keras.Input(shape=label_shape, name="label")

    # Concatenate the processed images and labels. Use the same `keras.layers`
    # namespace as the rest of the model instead of mixing in `tf.keras`,
    # which can resolve to a different Keras package.
    x = layers.Concatenate()([x, labels])

    # Process the concatenated data
    x = layers.Dense(latent_dim, activation=ACT)(x)

    # Compute the mean and log variance of the latent space
    z_mean = layers.Dense(latent_dim, name="z_mean")(x)
    z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)

    # Sample using the reparameterization trick
    z = Sampling()([z_mean, z_log_var])

    # Return the encoder model
    encoder = keras.Model([inputs, labels], [z_mean, z_log_var, z], name="encoder")
    return encoder
# ----------------------------------------


# Decoder architecture
# -------------------------------
def decoder_builder(conv_arch, conv_trans_arch, latent_dim):
    """Build the conditional decoder model.

    The latent vector is concatenated with the label vector and projected by a
    dense layer to 16*16*128 values, which are reshaped to (16, 16, 128). Four
    stride-2 transposed convolutions (128, 64, 32, 16 filters) follow, then one
    stride-1 transposed convolution (8 filters), all with padding="same". For
    Conv2DTranspose this upsamples 16x16 to 256x256. A final ``conv_arch`` layer
    with 3 filters and sigmoid activation produces the image, which is resized
    to ``img_shape[:-1]``.

    Args:
        conv_arch: convolution layer class for the output layer.
        conv_trans_arch: transposed-convolution layer class for upsampling.
        latent_dim: size of the latent input.

    Returns:
        keras.Model named "decoder" with inputs ``[latent, label]`` and the
        resized image as output.
    """
    # Latent space input
    latent_inputs = keras.Input(shape=(latent_dim,))

    # Labels. Use the shape tuple; Keras 3 rejects a bare int shape.
    condition_inp = keras.Input(shape=label_shape)

    # Concatenate the latent and labels (same `keras.layers` namespace as the
    # rest of the model, not `tf.keras`)
    concatenated_cond = layers.Concatenate()([latent_inputs, condition_inp])

    # Reconstruct the input
    x = layers.Dense(16 * 16 * 128, activation=ACT)(concatenated_cond)
    x = layers.Reshape((16, 16, 128))(x)
    for filters in (128, 64, 32, 16):
        x = conv_trans_arch(filters, 3, 2, padding="same", activation=ACT)(x)
    x = conv_trans_arch(8, 3, 1, padding="same", activation=ACT)(x)
    x = conv_arch(3, 3, 1, padding="same", activation="sigmoid")(x)

    # Resize the upsampled output to the dataset's (height, width)
    decoder_outputs = layers.Resizing(*img_shape[:-1])(x)
    # Return the decoder model
    decoder = keras.Model([latent_inputs, condition_inp], decoder_outputs, name="decoder")
    return decoder
# -------------------------------

# Defining some paths: all artifacts of this experiment live under RESULTS_DIR
RESULTS_DIR = 'results/cvae/celeb_a/'
MODELS_DIR, HISTORY_DIR, IMAGE_DIR = (
    RESULTS_DIR + sub_dir for sub_dir in ('model/', 'history/', 'images/')
)
# Create every output directory up front, including the per-architecture
# model and history directories.
for _out_dir in (
    RESULTS_DIR,
    MODELS_DIR,
    HISTORY_DIR,
    IMAGE_DIR,
    MODELS_DIR + CONV_ARCH.__name__,
    HISTORY_DIR + CONV_ARCH.__name__,
):
    os.makedirs(_out_dir, exist_ok=True)







# Training the VAE
# -------------------------------
# Define the learning-rate scheduler
def lr_schedule(epoch, lr, *, decay=0.98, min_lr=0.0001):
    """Learning-rate schedule for ``keras.callbacks.LearningRateScheduler``.

    The rate is kept unchanged for epochs 0 and 1. From epoch 2 onward it is
    multiplied by ``decay`` and rounded to 6 decimals, and it never goes
    below ``min_lr``.

    Args:
        epoch: zero-based epoch index supplied by Keras.
        lr: current learning rate supplied by Keras.
        decay: multiplicative decay applied once per epoch after epoch 1.
        min_lr: lower bound for the learning rate once decay has started.

    Returns:
        The learning rate to use for ``epoch``.
    """
    if epoch <= 1:
        return lr
    if lr <= min_lr:
        return min_lr
    # Clamp the decayed value so one step cannot undershoot the floor
    # (e.g. 0.000101 * 0.98 rounds to 0.000099).
    return max(round(lr * decay, 6), min_lr)

for latent_dim in LATENT_DIMS:
    for seed in [0, 1, 2, 3, 4]:
        # Drop the models/state of the previous run: this loop builds a new
        # model 15 times in one process, and Keras' global state would
        # otherwise keep growing. Seeding happens afterwards.
        keras.backend.clear_session()
        tf.random.set_seed(seed)
        np.random.seed(seed)

        # Build the model
        encoder = encoder_builder(conv_arch=CONV_ARCH, latent_dim=latent_dim)
        decoder = decoder_builder(conv_arch=CONV_ARCH, conv_trans_arch=CONV_TRANS_ARCH, latent_dim=latent_dim)
        vae = VAE(
            encoder=encoder,
            decoder=decoder,
            bce_weight=BCE_WEIGHT,
            mse_weight=MSE_WEIGHT,
            mae_weight=MAE_WEIGHT,
            ssim_weight=SSIM_WEIGHT,
            sharpness_weight=SHARPNESS_WEIGHT,
            kl_weight=KL_WEIGHT,
            img_shape=img_shape
        )

        # Compile the model
        vae.compile(optimizer=Adam(learning_rate=LR))

        # Run identifier shared by the model and history directories.
        # NOTE: the missing "(" after "_BCE" is intentional; it keeps the names
        # identical to the directories the loading/plotting code below reads.
        run_name = (
            f"/KL_Weight({KL_WEIGHT})_BCE{BCE_WEIGHT})_MSE({MSE_WEIGHT})"
            f"_MAE({MAE_WEIGHT})_Sharpness({SHARPNESS_WEIGHT})_SSIM({SSIM_WEIGHT})"
            f"_Seed({seed})_Epochs({DATASET['num_epochs']})"
            f"_InitLR({LR})_LATENT({latent_dim})_small/"
        )
        current_model_dir = MODELS_DIR + CONV_ARCH.__name__ + run_name
        os.makedirs(current_model_dir, exist_ok=True)

        # Train a model based on the current convolution architecture
        print(f"{CONV_ARCH.__name__} - KL_Weight: {KL_WEIGHT} - Latent Dim: {latent_dim} - Seed: {seed}")
        # NOTE(review): if load_dataset returns already-batched datasets, Keras
        # documents that batch_size should not be passed to fit — confirm.
        vae_history = vae.fit(
            x_train,
            validation_data=x_test,
            epochs=DATASET['num_epochs'],
            batch_size=DATASET['batch_size'],
            callbacks=[
                callbacks.LearningRateScheduler(lr_schedule, verbose=0)
            ]
        )

        # Save the model weights
        vae.save_weights(current_model_dir + "weights")

        # Convert the per-epoch metrics dict to a DataFrame and store it as
        # both JSON and CSV.
        raw_history = vae_history.history
        hist_df = pd.DataFrame(raw_history)

        current_history_dir = HISTORY_DIR + CONV_ARCH.__name__ + run_name
        os.makedirs(current_history_dir, exist_ok=True)
        hist_df.to_json(current_history_dir + 'history.json')
        hist_df.to_csv(current_history_dir + 'history.csv')
# -------------------------------





# Plotting the loss curves
# -------------------------------
from utils.plot_tools import plot_histories, plot_histories_seeds
# For each latent size: plot one loss figure per seed comparing all
# architectures, then one aggregated figure over SEEDS (mean curve with the
# std as a shaded area).
CONV_ARCHS = [Conv2D, GeoConv2D]
SEEDS = [0, 1, 2, 3, 4]
LATENT_DIMS = [256, 384, 512]
legend_on = True  # only the first aggregated figure gets a legend
for latent_dim in LATENT_DIMS:
    # seed -> list of history DataFrames, one per entry of CONV_ARCHS
    all_seeds_histories = {}
    for seed in SEEDS:
        histories = []
        for conv_architecture in CONV_ARCHS:
            # NOTE: "_BCE" + value + ")" (no opening parenthesis) matches the
            # directory names written by the training loop.
            current_history_dir = HISTORY_DIR + conv_architecture.__name__ + (
                f"/KL_Weight({KL_WEIGHT})_BCE{BCE_WEIGHT})_MSE({MSE_WEIGHT})"
                f"_MAE({MAE_WEIGHT})_Sharpness({SHARPNESS_WEIGHT})_SSIM({SSIM_WEIGHT})"
                f"_Seed({seed})_Epochs({DATASET['num_epochs']})"
                f"_InitLR({LR})_LATENT({latent_dim})_small/"
            )
            raw_history = pd.read_csv(current_history_dir + 'history.csv')
            histories.append(raw_history)
        all_seeds_histories[seed] = histories

        # One curve per architecture (works for any number of CONV_ARCHS)
        plot_histories(
            list(histories),
            [conv_arch.__name__ for conv_arch in CONV_ARCHS], 'loss',
            RESULTS_DIR + f"Loss_KL-Beta({KL_WEIGHT})_Latent-Dim({latent_dim})_Seed({seed})_Epochs({DATASET['num_epochs']}).pdf",
            validation=True,
            extracted=True,
            y_lims=[8000, 16000]
        )

    # Mean and std over seeds of the histories collected in all_seeds_histories
    plot_histories_seeds(
        all_seeds_histories,
        [conv_arch.__name__ for conv_arch in CONV_ARCHS],
        SEEDS,
        'loss',
        save=RESULTS_DIR + f"Loss_KL-Beta({KL_WEIGHT})_Latent-Dim({latent_dim})_Epochs({DATASET['num_epochs']})_aggregated.pdf",
        validation=True,
        extracted=True,
        y_lims=[8000, 14000],
        legend_on=legend_on
    )
    legend_on = False
# -------------------------------








# Plotting Generated Images
# -------------------------------
# CelebA attribute names in label-vector order; an attribute's position in
# this list is its index in the 40-dimensional label.
_CELEBA_ATTRIBUTES = [
    '5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive', 'Bags_Under_Eyes',
    'Bald', 'Bangs', 'Big_Lips', 'Big_Nose', 'Black_Hair', 'Blond_Hair',
    'Blurry', 'Brown_Hair', 'Bushy_Eyebrows', 'Chubby', 'Double_Chin',
    'Eyeglasses', 'Goatee', 'Gray_Hair', 'Heavy_Makeup', 'High_Cheekbones',
    'Male', 'Mouth_Slightly_Open', 'Mustache', 'Narrow_Eyes', 'No_Beard',
    'Oval_Face', 'Pale_Skin', 'Pointy_Nose', 'Receding_Hairline', 'Rosy_Cheeks',
    'Sideburns', 'Smiling', 'Straight_Hair', 'Wavy_Hair', 'Wearing_Earrings',
    'Wearing_Hat', 'Wearing_Lipstick', 'Wearing_Necklace', 'Wearing_Necktie',
    'Young',
]
# Attribute name -> index in the label vector
attribute_index_dict = {name: index for index, name in enumerate(_CELEBA_ATTRIBUTES)}



# Load the model with the below parameters
# Specifications of the run to load. They are part of the run's directory
# name, so they must equal the values the run was trained with.
LR = 0.002
latent_dim = 384
KL_WEIGHT = 50.
BCE_WEIGHT = .1
MSE_WEIGHT = .5
MAE_WEIGHT = .5
SSIM_WEIGHT = 5000.
SHARPNESS_WEIGHT = 20000.
seed = 3
CONV_ARCHS = [Conv2D, GeoConv2D]
CONV_TRANS_ARCHS = [Conv2DTranspose, GeoConv2DTranspose]


def _build_and_load_vae(conv_arch, conv_trans_arch):
    """Build and compile one VAE for the run selected above and load its weights.

    Returns:
        Tuple ``(encoder, decoder, vae, model_dir)``; ``model_dir`` is the
        directory the weights were loaded from.
    """
    encoder = encoder_builder(conv_arch=conv_arch, latent_dim=latent_dim)
    decoder = decoder_builder(conv_arch=conv_arch, conv_trans_arch=conv_trans_arch, latent_dim=latent_dim)
    vae = VAE(
        encoder=encoder,
        decoder=decoder,
        bce_weight=BCE_WEIGHT,
        mse_weight=MSE_WEIGHT,
        mae_weight=MAE_WEIGHT,
        ssim_weight=SSIM_WEIGHT,
        sharpness_weight=SHARPNESS_WEIGHT,
        kl_weight=KL_WEIGHT,
        img_shape=img_shape
    )
    vae.compile(optimizer=Adam(learning_rate=LR))
    # NOTE: "_BCE" + value + ")" (no opening parenthesis) matches the directory
    # names produced by the training loop.
    model_dir = MODELS_DIR + conv_arch.__name__ + (
        f"/KL_Weight({KL_WEIGHT})_BCE{BCE_WEIGHT})_MSE({MSE_WEIGHT})"
        f"_MAE({MAE_WEIGHT})_Sharpness({SHARPNESS_WEIGHT})_SSIM({SSIM_WEIGHT})"
        f"_Seed({seed})_Epochs({DATASET['num_epochs']})"
        f"_InitLR({LR})_LATENT({latent_dim})_small/"
    )
    vae.load_weights(model_dir + 'weights')
    return encoder, decoder, vae, model_dir


# Build and load the Conv2D model first, then the GeoConv2D model
conv_encoder, conv_decoder, conv_vae, conv_model_dir = _build_and_load_vae(Conv2D, Conv2DTranspose)
geo_encoder, geo_decoder, geo_vae, geo_model_dir = _build_and_load_vae(GeoConv2D, GeoConv2DTranspose)


# Defines a function that takes a sample point and a label and plots the reconstructed image
def plot_reconstructed_image(sample_point, label, title):
    """Decode one (latent, label) pair with both models and show them side by side.

    Left panel: output of the module-level ``geo_vae`` decoder. Right panel:
    output of ``conv_vae``. Only the first item of each decoded batch is shown.

    Args:
        sample_point: latent vector, passed as the decoder's first input.
        label: condition vector, passed as the decoder's second input.
        title: figure-level title.
    """
    fig, axes = plt.subplots(1, 2)
    for ax, vae, name in zip(axes, (geo_vae, conv_vae), ('GeoConv2D', 'Conv2D')):
        decoded = vae.decoder([sample_point, label])
        ax.imshow(tf.squeeze(decoded[0]))
        ax.set_title(name)
        # Remove the x and y ticks
        ax.set_xticks([])
        ax.set_yticks([])
    # Set the suptitle before tight_layout so the layout can account for it
    # instead of letting it overlap the subplot titles.
    fig.suptitle(title)
    fig.tight_layout()
    plt.show()
    plt.close(fig)

def plot_reconstructed_images_in_rows(sample_point, labels, title, count):
    """Decode one latent vector under several labels with both models and save the grid.

    Row 0 holds the ``geo_vae`` decoder outputs (labelled "GeoConv") and row 1
    the ``conv_vae`` outputs (labelled "Simple Conv"). Column ``i`` uses
    ``labels[i]``. The figure is saved as a PDF in ``IMAGE_DIR``, then shown.

    Args:
        sample_point: latent vector fed to both decoders.
        labels: non-empty sequence of condition vectors, one per column.
        title: not used; the figure is saved without a suptitle.
        count: index embedded in the output filename.

    Raises:
        ValueError: if ``labels`` is empty.
    """
    if len(labels) == 0:
        raise ValueError("labels must not be empty")
    # squeeze=False keeps `axes` 2-D even for a single label, so axes[row, col]
    # indexing works for any number of columns.
    fig, axes = plt.subplots(2, len(labels), figsize=(14, 4), squeeze=False)
    # Row names on the left of the first column, rotated 90 degrees
    axes[0, 0].set_ylabel('GeoConv', rotation=90, size='medium', labelpad=10, verticalalignment='center',
                          horizontalalignment='center')
    axes[1, 0].set_ylabel('Simple Conv', rotation=90, size='medium', labelpad=10, verticalalignment='center',
                          horizontalalignment='center')
    for col, label in enumerate(labels):
        for row, vae in enumerate((geo_vae, conv_vae)):
            decoded = vae.decoder([sample_point, label])
            ax = axes[row, col]
            ax.imshow(tf.squeeze(decoded[0]))
            ax.set_xticks([])
            ax.set_yticks([])
    fig.tight_layout()
    fig.savefig(IMAGE_DIR + f"Loss_KL_Beta({KL_WEIGHT})_Latent_Dim({latent_dim})_Seed_{count}.pdf")
    plt.show()
    plt.close(fig)


def get_label(attributes):
    """Return the multi-hot label vector for a list of CelebA attribute names.

    The result is a NumPy float array of length ``label_shape[0]``, with zeros
    everywhere except at the ``attribute_index_dict`` positions of the given
    attributes, which are set to 1.

    Raises:
        KeyError: if a name is not a key of ``attribute_index_dict``.
    """
    positions = [attribute_index_dict[name] for name in attributes]
    vector = np.zeros(label_shape[0])
    vector[positions] = 1
    return vector

# Draw the latent vectors for the label-comparison figures with a fixed NumPy seed
np.random.seed(1)
Num_samples = 5
# Each vector has shape (1, latent_dim): standard-normal draws scaled by 0.3
sample_points = [np.random.normal(size=(1, latent_dim)) * .3 for _ in range(Num_samples)]

# Attribute sets come in pairs: the second of each pair replaces 'No_Beard'
# with 'Male' (the last pair also adds 'Goatee').
attributes_1 = ['Blond_Hair', 'Straight_Hair', 'Pale_Skin', 'No_Beard']
attributes_2 = ['Blond_Hair', 'Straight_Hair', 'Pale_Skin', 'Male']
attributes_3 = ['Young', 'Mouth_Slightly_Open', 'Chubby', 'Eyeglasses', 'Smiling', 'No_Beard']
attributes_4 = ['Young', 'Mouth_Slightly_Open', 'Chubby', 'Eyeglasses', 'Smiling', 'Male']
attributes_5 = ['Heavy_Makeup', 'Mouth_Slightly_Open', 'Pointy_Nose', 'Oval_Face', 'Smiling', 'Narrow_Eyes', 'No_Beard']
attributes_6 = ['Heavy_Makeup', 'Mouth_Slightly_Open', 'Pointy_Nose', 'Oval_Face', 'Smiling', 'Narrow_Eyes', 'Male']
attributes_7 = ['Rosy_Cheeks', 'Arched_Eyebrows', 'Big_Nose', 'Bags_Under_Eyes', 'Smiling', 'High_Cheekbones', 'Wearing_Lipstick', 'No_Beard']
attributes_8 = ['Rosy_Cheeks', 'Arched_Eyebrows', 'Big_Nose', 'Bags_Under_Eyes', 'Smiling', 'High_Cheekbones', 'Wearing_Lipstick', 'Male', 'Goatee']

attribute_set = [attributes_1, attributes_2, attributes_3, attributes_4, attributes_5, attributes_6, attributes_7, attributes_8]
# One (1, label_shape[0]) label tensor per attribute set
labels = [tf.reshape(get_label(attributes), (1, label_shape[0])) for attributes in attribute_set]
# One saved figure per latent vector, with one column per label
for count, sample_point in enumerate(sample_points):
    plot_reconstructed_images_in_rows(sample_point, labels, title=f'Sample_Latent {count + 1}', count=count)

# -------------------------------


# Same as the snippet above, but with one fixed attribute set for every column and a different latent vector per column
# -------------------------------
def plot_reconstructed_images_of_different_seeds_in_columns(sample_points, label):
    """Decode several latent vectors under one label with both models and save the grid.

    Row 0 holds the ``geo_vae`` decoder outputs (labelled "GeoConv") and row 1
    the ``conv_vae`` outputs (labelled "Simple Conv"). Column ``i`` uses
    ``sample_points[i]``. The figure is saved as a PDF in ``IMAGE_DIR``, then
    shown.

    Args:
        sample_points: non-empty sequence of latent vectors, one per column.
        label: condition vector shared by every column.

    Raises:
        ValueError: if ``sample_points`` is empty.
    """
    if len(sample_points) == 0:
        raise ValueError("sample_points must not be empty")
    # squeeze=False keeps `axes` 2-D even for a single column.
    fig, axes = plt.subplots(2, len(sample_points), figsize=(14, 4), squeeze=False)
    # Row names on the left of the first column, rotated 90 degrees
    axes[0, 0].set_ylabel('GeoConv', rotation=90, size='medium', labelpad=10, verticalalignment='center',
                          horizontalalignment='center')
    axes[1, 0].set_ylabel('Simple Conv', rotation=90, size='medium', labelpad=10, verticalalignment='center',
                          horizontalalignment='center')
    for col, sample_point in enumerate(sample_points):
        for row, vae in enumerate((geo_vae, conv_vae)):
            decoded = vae.decoder([sample_point, label])
            ax = axes[row, col]
            ax.imshow(tf.squeeze(decoded[0]))
            ax.set_xticks([])
            ax.set_yticks([])
    fig.tight_layout()
    fig.savefig(IMAGE_DIR + f"Loss_KL_Beta({KL_WEIGHT})_Latent_Dim({latent_dim})_OneLabelOnly.pdf")
    plt.show()
    plt.close(fig)

# Re-seed NumPy so this figure is reproducible
np.random.seed(1)
Num_samples = 8
# Each vector has shape (1, latent_dim): standard-normal draws scaled by 0.3
sample_points = [np.random.normal(size=(1, latent_dim)) * .3 for _ in range(Num_samples)]

# A single attribute set shared by every column
attributes_1 = ['Rosy_Cheeks', 'Arched_Eyebrows', 'Big_Nose', 'Bags_Under_Eyes', 'Smiling', 'High_Cheekbones', 'Wearing_Lipstick', 'No_Beard']
labels = attributes_1
label = tf.reshape(get_label(labels), (1, label_shape[0]))
plot_reconstructed_images_of_different_seeds_in_columns(sample_points, label)
# -------------------------------
