import datetime
import pickle
import torch
import os
import sys
file_dir = os.path.dirname(__file__)
sys.path.append(file_dir)
from vae import VAE

def save_model(model, training_params, task_params, name=None, directory=None):
    """
    Save a VAE model together with its parameter dictionaries.

    Five files are written, all prefixed with ``name`` inside ``directory``:
    ``_vae_params.pkl``, ``_training_params.pkl`` and ``_task_params.pkl``
    (pickled dictionaries) plus ``_state_dict_prior.pkl`` and
    ``_state_dict_enc.pkl`` (torch state dicts of ``model.prior`` and
    ``model.encoder``). The directory is not created; it must already exist.

    Args:
        model (nn.Module): VAE model; must expose ``vae_params``, ``dim_z``,
            ``prior`` and ``encoder``
        training_params (dict): dictionary of training parameters
        task_params (dict): dictionary of task parameters; ``task_params['name']``
            is used when a model name has to be generated
        name (str, optional): name of the model. If omitted, a name of the form
            '<task>_<enc_arch>_<prior_arch>_Z_<dim_z>_Date_<timestamp>' is generated
        directory (str, optional): directory where the model is saved. Defaults to
            '../models/' when the name is generated and to the current working
            directory when a name is given. A trailing separator is optional.

    Returns:
        name (str): name of the model (without the directory); pass
            ``os.path.join(directory, name)`` to ``load_model``
    """
    if not name:
        if not directory:
            directory = '../models/'
        timestamp = datetime.datetime.now().strftime("%Y_%m_%d_T_%H_%M_%S")
        # Separate dim_z from the timestamp so the two numbers don't run together.
        name = (
            f"{task_params['name']}_{model.vae_params['enc_architecture']}_"
            f"{model.vae_params['prior_architecture']}_Z_{model.dim_z}_Date_{timestamp}"
        )
        print("Saving model as " + str(name))
    elif not directory:
        directory = ''

    # os.path.join inserts a separator only when needed, in both naming branches
    # (os.path.join('', name) == name).
    prefix = os.path.join(directory, name)

    with open(prefix + "_vae_params.pkl", "wb") as f:
        pickle.dump(model.vae_params, f)
    with open(prefix + "_training_params.pkl", "wb") as f:
        pickle.dump(training_params, f)
    with open(prefix + "_task_params.pkl", "wb") as f:
        pickle.dump(task_params, f)

    torch.save(model.prior.state_dict(), prefix + "_state_dict_prior.pkl")
    torch.save(model.encoder.state_dict(), prefix + "_state_dict_enc.pkl")

    return name

def load_model(name):
    """
    Load a VAE written by ``save_model``.

    Reads the three pickled parameter dictionaries and the two torch state
    dicts sharing the prefix ``name``. It then patches parameter entries that
    the code below treats as legacy, builds the VAE, and loads the prior and
    encoder weights mapped to CPU.

    Args:
        name (str): path prefix of the saved files, i.e. directory plus the
            model name returned by ``save_model``

    Returns:
        model: initialized VAE with prior and encoder weights loaded
        vae_params (dict): (possibly patched) dictionary of model parameters
        task_params (dict): dictionary of task parameters
        training_params (dict): dictionary of training parameters

    Note:
        ``pickle.load`` and ``torch.load`` can execute arbitrary code while
        unpickling; only load files from trusted sources.
    """
    state_dict_file_prior = name + "_state_dict_prior.pkl"
    state_dict_file_encoder = name + "_state_dict_enc.pkl"
    params_file = name + "_vae_params.pkl"
    task_params_file = name + "_task_params.pkl"
    training_params_file = name + "_training_params.pkl"

    with open(params_file, "rb") as f:
        vae_params = pickle.load(f)
    with open(task_params_file, "rb") as f:
        task_params = pickle.load(f)
    with open(training_params_file, "rb") as f:
        training_params = pickle.load(f)

    # Fill in options missing from (presumably older) saved parameter files.
    vae_params.setdefault("share_eps", False)
    vae_params.setdefault("train_alpha", False)

    prior_params = vae_params['prior_params']
    # A boolean True for readout_rates is translated to the string "rates".
    if prior_params.get('readout_rates') is True:
        prior_params['readout_rates'] = "rates"
    # activation='relu' combined with a truthy prior_params['clipped'] flag
    # becomes 'clipped_relu'. The flag is read from prior_params, the same dict
    # that holds 'activation'.
    if prior_params.get('activation') == 'relu' and prior_params.get('clipped', False):
        prior_params['activation'] = 'clipped_relu'

    model = VAE(vae_params)

    cpu = torch.device('cpu')
    prior_state = torch.load(state_dict_file_prior, map_location=cpu)
    # Rename "latent_step" keys to "transition" (presumably an earlier name of
    # the prior's submodule). The dict is mutated in place so any attributes
    # torch attached to it (e.g. _metadata) are kept.
    for key in list(prior_state):
        new_key = key.replace("latent_step", "transition")
        if new_key != key:
            prior_state[new_key] = prior_state.pop(key)
    model.prior.load_state_dict(prior_state)
    model.encoder.load_state_dict(torch.load(state_dict_file_encoder, map_location=cpu))

    return model, vae_params, task_params, training_params
