import os
import time

import pickle
import json
import tensorflow as tf
import numpy as np
import pandas as pd

from matplotlib.pyplot import figure
from modules.vae.transformvae import TransformVAE
from modules.vae.hypertorus_transformvae import HypertorusTransformVAE
from modules.vae.standard_transformvae import StandardTransformVAE
from modules.vae.hypercylinder_transformvae import HypercylinderTransformVAE
from modules.vae.shapeculturevae import ShapeCultureVAE


def create_folder(folder_path):
    """
    Create ``folder_path`` together with any missing parent folders.

    Calling it on a folder that already exists is a no-op.

    Args:
        folder_path: path of the folder to create.
    """
    os.makedirs(folder_path, mode=0o777, exist_ok=True)


def convert_dictionary_entries_to_string(dictionary: dict) -> dict:
    """
    Return a copy of ``dictionary`` whose values have been converted with ``str``.

    Keys are kept unchanged. The input dictionary is NOT modified. Callers can therefore
    serialize the string version (e.g. to JSON) and keep using the original value types.

    Args:
        dictionary: mapping whose values should be stringified.
    Returns:
        A new dict with the same keys and ``str(value)`` for every value.
    """
    return {key: str(value) for key, value in dictionary.items()}


# Class names of the VAE models whose parameters Experiment may save and later re-create.
AVAILABLE_MODELS = [
    model_class.__name__
    for model_class in (
        TransformVAE,
        HypertorusTransformVAE,
        StandardTransformVAE,
        ShapeCultureVAE,
        HypercylinderTransformVAE,
    )
]


class Experiment:
    """
    Experiment class for saving locally relevant information. The folder structure generated for an experiment
    is as follows
    |-path
    ||-experiment_name
    |||-timestamp -> timestamp generated at the moment of running start_experiment() method
    ||||-images -> saving relevant images
    ||||-parameters -> saving dictionaries with parameters of the experiment
    ||||-model_parameters -> saving model parameters to re-create model
    ||||-tensorboard -> saving tensorboard logs
    ||||-weights -> saving weights from Keras model
    ||||-logs -> saving arbitrary logs
    ||||-gifs -> saving created gifs
    ||||-info -> saving experiment information

    Every file written inside these folders is named ``<timestamp>_<name>_<extension>``.
    """

    def __init__(self, path, experiment_name):
        """
        Args:
            path: root folder that holds all experiments.
            experiment_name: sub-folder of ``path`` grouping every run of this experiment.
        """
        # Experiment properties
        self.path = path
        self.experiment_name = experiment_name
        self.experiment_path = os.path.join(self.path, self.experiment_name)
        # None until start_experiment() or a select_* method sets the timestamp of the current run.
        self.timestamp = None
        # Parameters
        self.model_name = None
        self.exp_parameters = {}
        # {model class name: constructor keyword arguments}, see set_model_parameters().
        self.model_parameters = {}

    @property
    def timestamp_path(self):
        assert self.timestamp is not None, "Experiment has not been started"
        path = os.path.join(self.experiment_path, self.timestamp)
        return path

    @property
    def images_path(self):
        return os.path.join(self.timestamp_path, "images")

    @property
    def weights_path(self):
        return os.path.join(self.timestamp_path, "weights")

    @property
    def tensorboard_path(self):
        return os.path.join(self.timestamp_path, "tensorboard")

    @property
    def parameters_path(self):
        return os.path.join(self.timestamp_path, "parameters")

    @property
    def model_parameters_path(self):
        return os.path.join(self.timestamp_path, "model_parameters")

    @property
    def gifs_path(self):
        return os.path.join(self.timestamp_path, "gifs")

    @property
    def logs_path(self):
        return os.path.join(self.timestamp_path, "logs")

    @property
    def info_path(self):
        return os.path.join(self.timestamp_path, "info")

    def _file_name(self, name, extension, timestamp=None):
        """
        Build the file name ``<timestamp>_<name>_<extension>`` shared by every file saved or loaded here.

        Keeping the scheme in one place guarantees that save and load methods agree on names.
        Args:
            name: descriptive part of the file name (may be empty).
            extension: file extension including the leading dot, e.g. ``".json"``.
            timestamp: run timestamp to use; defaults to ``self.timestamp``.
        Returns:
            The bare file name (no folder).
        """
        if timestamp is None:
            timestamp = self.timestamp
        return timestamp + "_" + name + "_" + extension

    # %%%%%%%%%%%%%%% SETTING FUNCTIONS %%%%%%%%%%%%%%%

    def start_experiment(self, parameters: dict):
        """
        Start a new run: stamp it with the current local time, create its folder and save ``parameters``.

        Args:
            parameters: experiment parameters written to the run's parameters JSON file.
        """
        # Zero-padded format, so lexicographic order of timestamps is chronological order.
        self.timestamp = time.strftime("%Y-%m-%d-%H-%M-%S_")
        create_folder(self.timestamp_path)
        self.update_experiment_parameters(parameters)

    def update_experiment_parameters(self, parameters: dict):
        """
        Merge ``parameters`` into the experiment parameters and rewrite the run's parameters JSON file.

        Values are stored in the JSON file as strings. ``self.exp_parameters`` keeps the original values.
        Args:
            parameters: parameters to add or overwrite.
        """
        print("Updating parameters...")
        self.exp_parameters.update(parameters)
        create_folder(self.parameters_path)
        parameters_file = os.path.join(self.parameters_path, self._file_name("parameters", ".json"))
        with open(parameters_file, 'w') as f:
            # Pass a copy so the conversion to strings never leaks back into self.exp_parameters.
            json.dump(convert_dictionary_entries_to_string(dict(self.exp_parameters)), f)

    def set_model_parameters(self, model_parameters, model_type):
        """
        Store the constructor arguments of the model and pickle them into the run folder.

        Args:
            model_parameters: keyword arguments used to build the model (see recreate_model()).
            model_type: class name of the model; must be one of AVAILABLE_MODELS.
        """
        assert model_type in AVAILABLE_MODELS, "Model type not available for saving parameters"
        self.model_parameters = {model_type: model_parameters}
        create_folder(self.model_parameters_path)
        model_parameters_file = os.path.join(self.model_parameters_path,
                                             self._file_name("modelparameters", ".pickle"))
        print("Updated experiment parameters with model parameters ", self.model_parameters)
        with open(model_parameters_file, 'wb') as f:
            pickle.dump(self.model_parameters, f, protocol=pickle.HIGHEST_PROTOCOL)

    def save_model_weights(self, model: tf.keras.Model, name: str = ""):
        """
        Save the weights of the model in the weights folder with the appropriate time stamp.
        Args:
            model: Keras Model whose weights are saved.
            name: extra name added to the saved weights
        """
        create_folder(self.weights_path)
        weights_name = self._file_name(name, ".h5")
        print("Saving model weights in {} as {}".format(self.weights_path, weights_name))
        model.save_weights(os.path.join(self.weights_path, weights_name))

    def set_tensorboard_cb(self) -> tf.keras.callbacks.Callback:
        """
        Sets the tensorboard callback function and creates the corresponding folder in the experiment path
        Returns:
            - tensorboard_cb : tensorboard callback logging to ``<tensorboard_path>/<timestamp>``
        """
        create_folder(self.tensorboard_path)
        tensorboard_file = os.path.join(self.tensorboard_path, self.timestamp)
        print("Tensorboard file created in ", tensorboard_file)
        tensorboard_cb = tf.keras.callbacks.TensorBoard(log_dir=tensorboard_file)
        return tensorboard_cb

    def save_figure(self, fig: figure, image_name="plot", extension=".png", sub_dir=None):
        """
        Save a matplotlib figure in the images folder (optionally in a sub-folder of it).

        Args:
            fig: figure object exposing ``savefig``.
            image_name: descriptive part of the file name.
            extension: file extension including the leading dot.
            sub_dir: optional sub-folder of the images folder.
        """
        path = self.images_path if sub_dir is None else os.path.join(self.images_path, sub_dir)
        create_folder(path)
        save_path = os.path.join(path, self._file_name(image_name, extension))
        fig.savefig(save_path, bbox_inches="tight")
        print("Image saved in {}".format(save_path))

    def save_logs_array(self, logs: np.ndarray, name: str = ""):
        """
        Save a numpy array in the logs folder as a ``.npy`` file.

        Args:
            logs: array to save.
            name: descriptive part of the file name.
        """
        create_folder(self.logs_path)
        np.save(os.path.join(self.logs_path, self._file_name(name, ".npy")), logs)

    def set_checkpoint_cb(self, monitor="loss", mode_checkpoint="min",
                          period_checkpoint=1) -> tf.keras.callbacks.Callback:
        """
        Create a ModelCheckpoint callback that keeps only the best weights in the weights folder.
        Args:
            monitor: metric name watched by the checkpoint.
            mode_checkpoint: Keras ``mode`` ("min", "max" or "auto") used to decide what "best" means.
            period_checkpoint: forwarded as Keras ``period``. NOTE(review): ``period`` is deprecated in
                recent tf.keras releases in favour of ``save_freq``; confirm against the TF version in use.
        Returns:
            The configured ModelCheckpoint callback.
        """
        # The checkpoint is written into the weights folder, so that is the folder that must exist.
        create_folder(self.weights_path)
        weights_file = os.path.join(self.weights_path, self._file_name("checkpoint", ".h5"))
        print("MODEL CHECKPOINT FILE", weights_file)
        # weights_file already contains weights_path; joining it again would duplicate relative paths.
        model_cbk = tf.keras.callbacks.ModelCheckpoint(weights_file,
                                                       monitor=monitor, verbose=0, save_best_only=True,
                                                       save_weights_only=True, mode=mode_checkpoint,
                                                       period=period_checkpoint)
        return model_cbk

    def dataframe_saved_parameters(self):
        """
        Collect the saved parameters of every previous run of this experiment into a DataFrame.

        Each sub-folder of ``experiment_path`` is treated as a run timestamp. Its parameters JSON file is
        read and a ``"timestamp"`` column is added. Runs are listed in sorted (chronological) order, so row
        numbers are stable between calls. ``self.timestamp`` is left untouched.
        Returns:
            pandas DataFrame with one row per run, or None when no previous run exists.
        """
        if os.path.isdir(self.experiment_path):
            timestamp_list = sorted(
                entry for entry in os.listdir(self.experiment_path)
                if os.path.isdir(os.path.join(self.experiment_path, entry))
            )
        else:
            timestamp_list = []

        if not timestamp_list:
            print("No previous experiments found")
            output_df = None
        else:
            parameter_dict_list = []
            for num_timestamp, timestamp in enumerate(timestamp_list):
                print(f"\n\t----- TIMESTAMP NUMBER {num_timestamp} : {timestamp} -----\n")
                # Built from the local timestamp so the current run's self.timestamp is not clobbered.
                parameters_file = os.path.join(self.experiment_path, timestamp, "parameters",
                                               self._file_name("parameters", ".json", timestamp))
                with open(parameters_file, 'r') as f:
                    parameter_dictionary = json.load(f)
                parameter_dictionary["timestamp"] = timestamp
                parameter_dict_list.append(parameter_dictionary)
            output_df = pd.DataFrame(parameter_dict_list)
        print(output_df)
        return output_df

    def _select_from_dataframe(self, df, num_experiment):
        """
        Make row ``num_experiment`` of ``df`` (from dataframe_saved_parameters) the current run.

        Sets ``self.timestamp`` and ``self.exp_parameters`` (a dict, as everywhere else in this class).
        Args:
            df: DataFrame returned by dataframe_saved_parameters(), or None.
            num_experiment: row index of the run to select.
        Returns:
            One-row DataFrame with the selected run.
        Raises:
            ValueError: if there are no previous runs.
            IndexError: if ``num_experiment`` does not match any row.
        """
        print(f"\n Selected experiment number {num_experiment}")
        if df is None:
            raise ValueError(f"No previous experiments found in {self.experiment_path}")
        df_selection = df[df.index == num_experiment]
        if df_selection.empty:
            raise IndexError(f"Experiment number {num_experiment} not found; "
                             f"valid numbers are 0 to {len(df) - 1}")
        self.timestamp = df_selection.iloc[0]["timestamp"]
        self.exp_parameters = df_selection.to_dict("records")[0]
        return df_selection

    def interactive_select_target_previous_experiment(self):
        """
        List the previous runs and ask on stdin which one to load as the current run.

        Returns:
            One-row DataFrame with the selected run.
        """
        df = self.dataframe_saved_parameters()
        if df is None:
            # Nothing to choose from: fail before prompting the user.
            raise ValueError(f"No previous experiments found in {self.experiment_path}")
        num_experiment = int(input(" Indicate the number of the timestamp experiment that you want to load: "))
        return self._select_from_dataframe(df, num_experiment)

    def select_target_previous_experiment(self, num_experiment):
        """
        Load previous run number ``num_experiment`` (row index in dataframe_saved_parameters()).

        Returns:
            One-row DataFrame with the selected run.
        """
        df = self.dataframe_saved_parameters()
        return self._select_from_dataframe(df, num_experiment)

    def load_parameters_name(self):
        """
        Load the pickled model parameters of the current run into ``self.model_parameters``.

        The file is read with ``pickle.load``, so it must come from a trusted source.
        """
        model_parameters_file = os.path.join(self.model_parameters_path,
                                             self._file_name("modelparameters", ".pickle"))
        with open(model_parameters_file, 'rb') as handle:
            self.model_parameters = pickle.load(handle)

    def recreate_model(self):
        """
        Instantiate the model described by ``self.model_parameters``.

        Returns:
            A new instance of the stored model class built with the stored keyword arguments.
        """
        model_name = list(self.model_parameters)[0]
        assert model_name in AVAILABLE_MODELS, f"Model {model_name} not available for re-creation"
        # Plain name lookup among this module's imports instead of eval(); the assert bounds the names.
        model_class = globals()[model_name]
        return model_class(**self.model_parameters[model_name])

    def load_weights(self, model: tf.keras.models.Model, name=""):
        """
        Load weights saved by save_model_weights() for the current run into ``model``.

        Args:
            model: Keras model with the same architecture as the saved one.
            name: the ``name`` used when saving the weights.
        Returns:
            The same ``model`` with its weights loaded.
        """
        weights_file = os.path.join(self.weights_path, self._file_name(name, ".h5"))
        model.load_weights(weights_file)
        return model
