import math
import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import pandas as pd
import keras
from keras import layers, metrics, losses, callbacks
from keras.layers import Conv2DTranspose, Conv2D
from keras.optimizers import Adam

from src.geo_conv import GeoConv2D, GeoConv2DTranspose
from evaluation.cvae.gesture_2012.dataset_loader import load_dataset
from utils.layers import Sampling
from utils.models import VAE


# Set the random seed
SEED = 0
tf.random.set_seed(SEED)
np.random.seed(SEED)

# Learning rate
LR = 0.0005

# Activation
ACT = tf.nn.leaky_relu

# Convolution architecture
CONV_ARCH = Conv2D
# Transposed convolution architecture, looked up from the class name of
# CONV_ARCH; any name missing from this table is rejected.
_TRANSPOSE_FOR_CONV = {
    "Conv2D": Conv2DTranspose,
    "GeoConv2D": GeoConv2DTranspose,
}
if CONV_ARCH.__name__ not in _TRANSPOSE_FOR_CONV:
    raise ValueError("Invalid convolution architecture.")
CONV_TRANS_ARCH = _TRANSPOSE_FOR_CONV[CONV_ARCH.__name__]


# Specify the dataset specifications
GESTURE_2012 = {
    'name': 'gesture_2012',
    'path': 'gesture_2012',
    'img_shape': (256, 256, 3),
    'label_shape': (36,),
    # TODO: change batch size to 64  or some suitable value
    'batch_size': 320,
    'num_epochs': 100,
}

DATASET = GESTURE_2012

RESULT_DIR = "results/cvae/"
os.makedirs(RESULT_DIR, exist_ok=True)

# Define the hyperparameters
# NOTE(review): LATENT_DIM is not read by the training / plotting loops in this
# script; they iterate over their own LATENT_DIMS lists instead.
LATENT_DIM = 8
img_shape = DATASET['img_shape']
label_shape = DATASET['label_shape']
batch_size = DATASET['batch_size']
# Loss-term weights passed to VAE(...) in the training loop.
KL_WEIGHT = 50.
BCE_WEIGHT = .1
MSE_WEIGHT = .5
MAE_WEIGHT = .5
SSIM_WEIGHT = 5000.
SHARPNESS_WEIGHT = 20000.
# Alternative weight setting, kept for reference:
# BCE_WEIGHT = 1.
# MSE_WEIGHT = 0.
# MAE_WEIGHT = 0.
# SSIM_WEIGHT = 0.
# SHARPNESS_WEIGHT = 0.
# NOTE(review): the lowercase kl_weight / bce_weight / leaky_relu below are not
# referenced elsewhere in this script; the VAE is built from the uppercase
# *_WEIGHT constants above.
# TODO: The initial value for this was 5e-9
kl_weight = 1e-7
bce_weight = 1e-2
leaky_relu = layers.LeakyReLU()
# -------------------------------
# Total number of elements in one image (product of img_shape).
img_size = int(np.prod(img_shape))
# -------------------------------

# load dataset
# -------------------------------
# Both splits share the dataset name and batch size; "train" is loaded first,
# then "test".
# NOTE(review): presumably load_dataset returns batched tf.data pipelines —
# confirm, since that makes the batch_size passed to fit() below redundant.
x_train, x_test = (
    load_dataset(
        dataset_name=DATASET["name"],
        split=split_name,
        batch_size=DATASET["batch_size"],
    )
    for split_name in ("train", "test")
)
# -------------------------------
# -------------------------------


# Encoder architecture
# -------------------------------
def encoder_builder(conv_arch, latent_dim):
    """Build the conditional encoder ``[image, label] -> [z_mean, z_log_var, z]``.

    Args:
        conv_arch: Convolution layer class (e.g. ``Conv2D`` or ``GeoConv2D``),
            called as ``conv_arch(filters, kernel_size, strides, activation=...)``.
        latent_dim: Size of the latent space. Also the width of the dense layer
            that precedes the mean / log-variance heads.

    Returns:
        A ``keras.Model`` named ``"encoder"``. Its inputs are ``[img, label]``,
        shaped by the module-level ``img_shape`` and ``label_shape``.
    """
    # Images: three stride-2 convolutions (no explicit padding), then flatten.
    inputs = keras.Input(shape=img_shape, name="img")
    x = conv_arch(64, 3, 2, activation=ACT)(inputs)
    x = conv_arch(128, 3, 2, activation=ACT)(x)
    x = conv_arch(256, 3, 2, activation=ACT)(x)
    x = layers.Flatten()(x)

    # Labels. Keras documents `shape` as a tuple; a bare int is only tolerated
    # by some Keras versions, so wrap the label width explicitly.
    labels = keras.Input(shape=(label_shape[0],), name="label")

    # Concatenate the processed images and labels. Use the same `layers`
    # namespace as every other layer instead of mixing in `tf.keras`.
    x = layers.Concatenate()([x, labels])

    # Process the concatenated data
    x = layers.Dense(latent_dim, activation=ACT)(x)

    # Compute the mean and log variance of the latent space
    z_mean = layers.Dense(latent_dim, name="z_mean")(x)
    z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)

    # Sample using the reparameterization trick
    z = Sampling()([z_mean, z_log_var])

    # Return the encoder model
    encoder = keras.Model([inputs, labels], [z_mean, z_log_var, z], name="encoder")
    return encoder
# ----------------------------------------


# Decoder architecture
# -------------------------------
def decoder_builder(conv_arch, conv_trans_arch, latent_dim):
    """Build the conditional decoder ``[z, label] -> image``.

    Args:
        conv_arch: Convolution class used for the final 3-channel output layer.
        conv_trans_arch: Transposed-convolution class used for upsampling.
        latent_dim: Size of the latent input vector.

    Returns:
        A ``keras.Model`` named ``"decoder"`` whose output has 3 channels and a
        sigmoid activation.
    """
    # Latent space input
    latent_inputs = keras.Input(shape=(latent_dim,))

    # Labels. Keras documents `shape` as a tuple, so wrap the label width.
    condition_inp = keras.Input(shape=(label_shape[0],))

    # Concatenate the latent and labels (same `layers` namespace as the rest).
    concatenated_cond = layers.Concatenate()([latent_inputs, condition_inp])

    # Reconstruct the input: project to a 16x16x128 map, then upsample.
    # With conv_trans_arch = Conv2DTranspose and padding="same", each stride-2
    # layer doubles height/width (16 -> 32 -> 64 -> 128 -> 256). The stride-1
    # layers keep the size.
    x = layers.Dense(16 * 16 * 128, activation=ACT)(concatenated_cond)
    x = layers.Reshape((16, 16, 128))(x)
    x = conv_trans_arch(128, 3, 2, padding="same", activation=ACT)(x)
    x = conv_trans_arch(64, 3, 2, padding="same", activation=ACT)(x)
    x = conv_trans_arch(32, 3, 2, padding="same", activation=ACT)(x)
    x = conv_trans_arch(16, 3, 2, padding="same", activation=ACT)(x)
    x = conv_trans_arch(8, 3, 1, padding="same", activation=ACT)(x)
    decoder_outputs = conv_arch(3, 3, 1, padding="same", activation="sigmoid")(x)

    # Return the decoder model
    decoder = keras.Model([latent_inputs, condition_inp], decoder_outputs, name="decoder")
    return decoder
# -------------------------------



# Learning rate scheduler
# Lower bound for the learning rate (name kept as-is for compatibility).
threshhold = 0.000001


def lr_schedule(epoch, lr, decay=0.97, floor=None):
    """Exponential learning-rate decay with a lower bound, for ``LearningRateScheduler``.

    Epochs 0 and 1 keep the incoming rate unchanged. From epoch 2 on, the rate
    is multiplied by ``decay`` once per call but is never returned below
    ``floor``.

    Args:
        epoch: Zero-based epoch index supplied by Keras.
        lr: Current learning rate.
        decay: Multiplicative factor applied per epoch (default 0.97).
        floor: Minimum learning rate. Defaults to the module-level ``threshhold``.

    Returns:
        The learning rate to use for ``epoch``.
    """
    if floor is None:
        floor = threshhold
    if epoch <= 1:
        return lr
    # Clamp in the same step as the decay. Decaying first and clamping only on
    # the next call would let the rate dip below the floor for one epoch and
    # then jump back up.
    return max(lr * decay, floor)


SEEDS = [0, 1, 2, 3, 4]
LATENT_DIMS = [192, 128, 64]
RESULTS_DIR = 'results/cvae/gesture_2012/'
MODELS_DIR = RESULTS_DIR + 'model/'
HISTORY_DIR = RESULTS_DIR + 'history/'
os.makedirs(RESULTS_DIR, exist_ok=True)
os.makedirs(MODELS_DIR, exist_ok=True)
os.makedirs(HISTORY_DIR, exist_ok=True)
os.makedirs(MODELS_DIR + CONV_ARCH.__name__, exist_ok=True)
os.makedirs(HISTORY_DIR + CONV_ARCH.__name__, exist_ok=True)


# Train one model per (latent_dim, seed) pair with the CONV_ARCH selected above.
# Each run's weights go under MODELS_DIR and its training history under HISTORY_DIR.
for latent_dim in LATENT_DIMS:
    for seed in SEEDS:
        print(f"{CONV_ARCH.__name__} - LATENT_DIM: {latent_dim} - SEED: {seed}")
        # Release Keras' global state from the previous run. Keras recommends
        # this when many models are created in a loop. It is done before
        # seeding so the seeds below still apply.
        keras.backend.clear_session()
        # set tf and np seeds
        tf.random.set_seed(seed)
        np.random.seed(seed)
        # Build the encoder and decoder
        encoder = encoder_builder(conv_arch=CONV_ARCH, latent_dim=latent_dim)
        decoder = decoder_builder(conv_arch=CONV_ARCH, conv_trans_arch=CONV_TRANS_ARCH, latent_dim=latent_dim)
        # Instantiate the autoencoder model
        vae = VAE(
            encoder=encoder,
            decoder=decoder,
            bce_weight=BCE_WEIGHT,
            mse_weight=MSE_WEIGHT,
            mae_weight=MAE_WEIGHT,
            ssim_weight=SSIM_WEIGHT,
            sharpness_weight=SHARPNESS_WEIGHT,
            kl_weight=KL_WEIGHT,
            img_shape=img_shape
        )
        # Compile the model
        vae.compile(optimizer=Adam(learning_rate=LR))

        # Directory name identifying this run, shared by the weights and history.
        # NOTE: the plotting and latent-space sections of this script rebuild
        # this exact string to find saved runs. Keep the format byte-identical,
        # including the unbalanced "_BCE<w>)" segment and the fixed global
        # "_Seed(SEED)" part, or existing results will no longer be found.
        run_name = (
            f"KL_Weight({KL_WEIGHT})_BCE{BCE_WEIGHT})_MSE({MSE_WEIGHT})"
            f"_MAE({MAE_WEIGHT})_Sharpness({SHARPNESS_WEIGHT})_SSIM({SSIM_WEIGHT})"
            f"_Seed({SEED})_Epochs({DATASET['num_epochs']})_InitLR({LR})"
            f"_LATENT({latent_dim})_SEED({seed})_Large/"
        )
        current_model_dir = MODELS_DIR + CONV_ARCH.__name__ + "/" + run_name
        current_history_dir = HISTORY_DIR + CONV_ARCH.__name__ + "/" + run_name
        checkpoint_filepath = current_model_dir + 'weights'

        # Train the model
        vae_history = vae.fit(
            x_train,
            validation_data=x_test,
            epochs=DATASET['num_epochs'],
            batch_size=DATASET['batch_size'],
            callbacks=[
                callbacks.LearningRateScheduler(lr_schedule, verbose=0),
            ]
        )
        # Make sure the run directory exists before writing the weights.
        os.makedirs(current_model_dir, exist_ok=True)
        vae.save_weights(checkpoint_filepath)

        # Persist the per-epoch metrics (history.history) as JSON and CSV.
        hist_df = pd.DataFrame(vae_history.history)
        os.makedirs(current_history_dir, exist_ok=True)
        # Pass file paths rather than text-mode handles: pandas requires handles
        # to be opened with newline='', otherwise CSV line endings are doubled
        # on Windows.
        hist_df.to_json(current_history_dir + 'history.json')
        hist_df.to_csv(current_history_dir + 'history.csv')
# -------------------------------










# Plotting the loss curves
# -------------------------------
from utils.plot_tools import plot_histories, plot_histories_seeds
# Go over seeds 0-4 and also both models (Conv2D, GeoConv2D), and plot the mean loss curve with shaded area as std
# NOTE(review): this reads the saved history of *every* architecture in
# CONV_ARCHS, so the training section must have been run once per CONV_ARCH.
# A missing run makes pd.read_csv raise FileNotFoundError.
CONV_ARCHS = [Conv2D, GeoConv2D]
SEEDS = [0, 1, 2, 3, 4]
LATENT_DIMS = [64, 128, 192]
legend_on = True
for latent_dim in LATENT_DIMS:
    # seed -> list of history DataFrames, one per entry of CONV_ARCHS (same order).
    all_seeds_histories = {}
    for seed in SEEDS:
        # Must match, byte for byte, the run directory name used when the
        # histories were saved by the training loop.
        run_name = (
            f"KL_Weight({KL_WEIGHT})_BCE{BCE_WEIGHT})_MSE({MSE_WEIGHT})"
            f"_MAE({MAE_WEIGHT})_Sharpness({SHARPNESS_WEIGHT})_SSIM({SSIM_WEIGHT})"
            f"_Seed({SEED})_Epochs({DATASET['num_epochs']})_InitLR({LR})"
            f"_LATENT({latent_dim})_SEED({seed})_Large/"
        )
        histories = [
            pd.read_csv(HISTORY_DIR + conv_architecture.__name__ + "/" + run_name + 'history.csv')
            for conv_architecture in CONV_ARCHS
        ]
        all_seeds_histories[seed] = list(histories)

        # Per-seed comparison of all architectures.
        plot_histories(
            histories,
            [arch.__name__ for arch in CONV_ARCHS], 'loss',
            RESULTS_DIR + f"Loss_KL-Beta({KL_WEIGHT})_Latent-Dim({latent_dim})_Seed({seed})_Epochs({DATASET['num_epochs']}).pdf",
            validation=True,
            extracted=True,
            y_lims=[3800, 8000]
        )

    # Plot the mean and std over seeds of the data stored in all_seeds_histories.
    plot_histories_seeds(
        all_seeds_histories,
        [arch.__name__ for arch in CONV_ARCHS],
        SEEDS,
        'loss',
        save=RESULTS_DIR + f"Loss_KL-Beta({KL_WEIGHT})_Latent-Dim({latent_dim})_Epochs({DATASET['num_epochs']})_aggregated.pdf",
        validation=True,
        extracted=True,
        y_lims=[3800, 8000],
        legend_on=legend_on
    )
# -------------------------------










# Display a grid of decoded gestures, one tile per class label
# -------------------------------
def plot_latent_space_2(vae, sample_point, conv_arch, seed = 0, latent_dim = 8, labels=36, invert_background=False):
    """Decode one latent point under every class label and save the image grid.

    The decoded images are tiled on two rows of ``ceil(labels / 2)`` tiles.
    The top x-axis names the tiles of the first row and the bottom x-axis names
    the second row. Names run ``0``-``9`` then ``a``-``z``, falling back to the
    decimal index beyond 36 labels. The figure is saved to
    ``results/cvae/gesture_2012/`` as PNG and PDF, then shown.

    Args:
        vae: Model exposing ``decoder``, called as ``decoder([sample_point, label])``.
        sample_point: Latent input for the decoder (the caller in this script
            passes an array of shape ``(1, latent_dim)``).
        conv_arch: Convolution class. Only its ``__name__`` is used, in the file name.
        seed: Run seed. Used only in the file name.
        latent_dim: Latent size. Used only in the file name.
        labels: Number of class labels to decode (``0 .. labels - 1``). Each is
            one-hot encoded with width ``label_shape[0]``.
        invert_background: If True, decoded values below 0.1 are set to 1 before
            plotting, and ``_inverted`` is appended to the file name.
    """
    import string

    digit_size = 256  # side length, in pixels, of one decoded tile
    per_row = math.ceil(labels / 2)
    figure = np.zeros((2 * digit_size, per_row * digit_size, 3))

    # Tile j goes to row j // per_row, column j % per_row.
    for j in range(labels):
        label = tf.reshape(tf.one_hot(j, label_shape[0]), (1, label_shape[0]))
        x_decoded = vae.decoder([sample_point, label])
        digit = tf.squeeze(x_decoded[0])
        if invert_background:
            digit = tf.where(digit < 1e-1, 1, digit)
        row, col = divmod(j, per_row)
        figure[
            row * digit_size:(row + 1) * digit_size,
            col * digit_size:(col + 1) * digit_size,
            :,
        ] = digit

    # Tick names: '0'-'9', then 'a'-'z'; indices past 35 fall back to str(j).
    symbols = string.digits + string.ascii_lowercase
    tick_names = [symbols[j] if j < len(symbols) else str(j) for j in range(labels)]
    top_names = tick_names[:per_row]
    # Pad the second row so each axis gets exactly one name per tick.
    bottom_names = tick_names[per_row:] + [""] * (2 * per_row - labels)
    # One tick at the centre of every tile column.
    pixel_range_x = np.arange(digit_size // 2, per_row * digit_size, digit_size)

    fig, ax = plt.subplots(figsize=(10, 2))
    # We do not need the y axis ticks
    ax.get_yaxis().set_visible(False)
    ax.set_xticks(pixel_range_x)
    ax.set_xticklabels(bottom_names)
    ax2 = ax.secondary_xaxis('top')
    ax2.set_xticks(pixel_range_x)
    ax2.set_xticklabels(top_names)
    ax.imshow(figure, cmap="Greys_r")
    fig.tight_layout()

    # Save before showing. With some backends the figure is closed or emptied
    # once plt.show() returns, which would write blank files.
    result_dir = 'results/cvae/gesture_2012/'
    os.makedirs(result_dir, exist_ok=True)
    file_name = f"{conv_arch.__name__}_LatentSpace_{latent_dim}_Seed_{seed}"
    if invert_background:
        file_name += '_inverted'
    fig.savefig(result_dir + file_name + ".png")
    fig.savefig(result_dir + file_name + ".pdf", format='pdf')
    plt.show()
    plt.close(fig)

LATENT_DIMS = [64, 128, 192]
SEEDS = [0, 1, 2, 3, 4]
# (convolution, transposed convolution) pairs to visualise, in plotting order.
LATENT_PLOT_ARCHS = [(GeoConv2D, GeoConv2DTranspose), (Conv2D, Conv2DTranspose)]

# For every latent size and seed, rebuild each architecture's CVAE, load its
# trained weights and plot the decoded grid for one shared latent point.
for latent_dim in LATENT_DIMS:
    # Re-seed so every seed/architecture decodes the same latent point.
    np.random.seed(0)
    sample_point = np.random.normal(size=(1, latent_dim)) * .3
    for seed in SEEDS:
        # Release the previous iteration's models. Keras recommends this when
        # many models are built in a loop.
        keras.backend.clear_session()
        for arch, trans_arch in LATENT_PLOT_ARCHS:
            vae = VAE(
                encoder=encoder_builder(arch, latent_dim),
                decoder=decoder_builder(arch, trans_arch, latent_dim),
                bce_weight=BCE_WEIGHT,
                mse_weight=MSE_WEIGHT,
                mae_weight=MAE_WEIGHT,
                ssim_weight=SSIM_WEIGHT,
                sharpness_weight=SHARPNESS_WEIGHT,
                kl_weight=KL_WEIGHT,
                img_shape=img_shape,
            )
            # compile the model
            vae.compile(optimizer=Adam(learning_rate=LR))
            # Must match, byte for byte, the run directory name used when the
            # weights were saved by the training loop.
            run_name = (
                f"KL_Weight({KL_WEIGHT})_BCE{BCE_WEIGHT})_MSE({MSE_WEIGHT})"
                f"_MAE({MAE_WEIGHT})_Sharpness({SHARPNESS_WEIGHT})_SSIM({SSIM_WEIGHT})"
                f"_Seed({SEED})_Epochs({DATASET['num_epochs']})_InitLR({LR})"
                f"_LATENT({latent_dim})_SEED({seed})_Large/"
            )
            # load the model
            vae.load_weights(MODELS_DIR + arch.__name__ + "/" + run_name + "weights")

            plot_latent_space_2(vae, sample_point, arch, seed=seed, latent_dim=latent_dim, labels=label_shape[0], invert_background=True)
# -------------------------------
