
from latent_vae.vae_latent import VAE, VAEModel
import argparse
import numpy as np
from jax import numpy as jnp
from jax import random
import os
from latent_vae.utils import make_output_dir
from copy import deepcopy
import pickle
def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_file', type=str, default="z_trees.npy")
    parser.add_argument('name', help="The name of the experiment and output directory.")
    parser.add_argument('--num_epochs', dest='num_epochs', type=int, default=1000)
    parser.add_argument('--batch_size', dest='batch_size', type=int, default=1000)
    parser.add_argument('--sample_size', dest='sample_size', type=int, default=10000)
    parser.add_argument('-lr', '--learning_rate', dest='learning_rate', type=float, default=0.0001)
    parser.add_argument('-ow', dest='overwrite', action='store_true')
    parser.add_argument('--layer_sizes', dest='layer_sizes', default='2048|2048|2048|2048|2048', \
                        help="Specify layer sizes for MLP (possibly others later) as integers separated by pipes. Example: 512|512|512")
    parser.add_argument('--encoder_layer_sizes', dest='encoder_layer_sizes', default='2048|2048|2048|2048|2048', \
                        help="Specify layer sizes for MLP (possibly others later) as integers separated by pipes. Example: 512|512|512")
    parser.add_argument('--latent_dim', dest='latent_dimension', type=int, default=20)
    parser.add_argument('--seed', dest='seed', type=int, default=0)
    parser.add_argument('--state_dict', dest='state_dict', default=None)
    parser.add_argument('--prev_dict', dest='prev_dict', default=None)
    parser.add_argument('-cfd', '--copy_flow_dataset', dest='copy_flow_dataset', action='store_true')
    parser.add_argument('-ws', '--warm_start', action='store_true')
    parser.add_argument('-e', '--epsilon', type=float, default=0.)
    parser.add_argument('-tdv', dest='tunable_decoder_var', action='store_true')
    parser.add_argument('--sample', dest='sample', type=str, default="")
    parser.add_argument('--refine', dest='refine', type=str, default="")
    parser.add_argument('--refine_type', dest='refine_type', type=str, default="", choices=["whole", "inner", "outer", ""])
    parser.add_argument('--square', dest='square', action='store_true')
    parser.add_argument('--encode', dest='encode', action='store_true')
    parser.add_argument('--output_latent_file', type=str, default="latent_samples.npy")
    parser.add_argument('--input_latent_file', type=str, default="")
    return parser.parse_args()

class data_loader():
    """Mini-batch loader that draws noisy samples from per-item Gaussians.

    The pickle file is expected to hold a sequence of ``(mean, log_variance)``
    pairs, one per data point.  A batch is produced by sampling
    ``mean + noise * exp(log_variance / 2)`` with standard-normal ``noise``.
    The first ``train_size`` items are used for training; the remainder form
    the held-out test batch.
    """

    def __init__(self, data_file, batch_size, split=0.95, seed=758493):
        """Load the dataset and prepare the first shuffled epoch.

        Args:
            data_file: path to the pickle file described in the class docstring.
            batch_size: number of items returned by each ``get_batch`` call.
            split: fraction of the data (taken from the front) used for training.
            seed: seed of the JAX PRNG stream used for the training noise.
        """
        # SECURITY NOTE: pickle.load can execute arbitrary code; only point
        # data_file at files from a trusted source.
        with open(data_file, "rb") as f:
            data = pickle.load(f)
        #self.data = jnp.array(data)
        mean, var = list(zip(*data))
        self.mean = jnp.array(mean)
        # Despite the name, this holds log-variances (see get_batch).
        self.var = jnp.array(var)
        print (self.mean.shape)
        self.dimension = len(data[0][0])
        self.data_size = len(data)
        self.batch_size = batch_size
        self.train_size = int(self.data_size*split)
        # Per-instance PRNG state, split on every training batch so that
        # successive batches do not reuse the same noise.
        self._noise_key = random.PRNGKey(seed)
        self.randomize()
        self.num_batch = self.train_size // self.batch_size
        self.start = 0

    def get_batch(self):
        """Return the next training batch of freshly sampled data.

        Advances the internal cursor by ``batch_size``.  Once the cursor runs
        past ``train_size`` the returned batch is short or empty; call
        ``randomize()`` to start a new epoch.
        """
        # JAX keys must not be reused: consume a fresh subkey per batch.
        self._noise_key, subkey = random.split(self._noise_key)
        idx_list = self.indices[self.start: self.start + self.batch_size]
        self.start += self.batch_size
        mean = self.mean[idx_list]
        log_var = self.var[idx_list]
        noise = random.normal(subkey, shape=mean.shape)
        # exp(log_var / 2) is the standard deviation.
        data = mean + noise * jnp.exp(log_var/2.)
        return data

    def get_test_batch(self):
        """Return one sample per held-out item (rows from ``train_size`` on).

        A fixed key is used on purpose so the test batch is identical on
        every call.
        """
        key = random.PRNGKey(1234)
        mean = self.mean[self.train_size:, :]
        log_var = self.var[self.train_size:, :]
        noise = random.normal(key, shape=mean.shape)
        data = mean + noise * jnp.exp(log_var/2.)
        return data

    def randomize(self):
        """Reshuffle the training indices and rewind the batch cursor."""
        self.indices = np.random.permutation(self.train_size)
        self.start = 0

class numpy_data_loader():
    """Mini-batch loader over an array stored in a ``.npy`` file.

    The leading ``train_size`` rows are served in shuffled mini-batches;
    the remaining rows are returned together as the test batch.
    """

    def __init__(self, data_file, batch_size, split=0.95):
        """Load ``data_file`` and prepare the first shuffled epoch.

        Args:
            data_file: path handed to ``np.load``.
            batch_size: number of rows returned by each ``get_batch`` call.
            split: fraction of rows (taken from the front) used for training.
        """
        raw = np.load(data_file, allow_pickle=True)
        self.data = jnp.array(raw)
        self.data_size = raw.shape[0]
        self.dimension = raw.shape[1]
        self.batch_size = batch_size
        self.train_size = int(self.data_size*split)
        self.num_batch = self.train_size // self.batch_size
        # Draws the first permutation and sets the cursor to the start.
        self.randomize()
        self.start = 0

    def get_batch(self):
        """Return the next ``batch_size`` training rows and advance the cursor."""
        begin = self.start
        end = begin + self.batch_size
        batch_indices = self.indices[begin:end]
        self.start = end
        return self.data[batch_indices]

    def get_test_batch(self):
        """Return every row from ``train_size`` onward."""
        held_out = self.data[self.train_size:, :]
        return held_out

    def randomize(self):
        """Draw a new permutation of the training rows and rewind the cursor."""
        self.start = 0
        self.indices = np.random.permutation(self.train_size)

if __name__ == '__main__':
    # Script entry point: build the dataset and the VAE model from the
    # command-line options, then run exactly one of three modes --
    # encode, sample-only, or train-then-sample -- and save the resulting
    # array under the experiment's output directory.
    args = parse_arguments()
    # Explicit check rather than `assert`, which is removed under `python -O`.
    if args.refine and not args.refine_type:
        raise ValueError("--refine requires --refine_type to be one of: whole, inner, outer")
    output_dir = make_output_dir(args.name, args.overwrite, args, None)

    # When encoding, use the whole dataset as the training portion.
    split = 1. if args.encode else 0.95

    # Pick the loader from the file extension; the instance gets its own
    # name so the `data_loader` class is not overwritten.
    if args.data_file.split('.')[-1] == "npy":
        dataset = numpy_data_loader(args.data_file, args.batch_size, split=split)
    else:
        dataset = data_loader(args.data_file, args.batch_size, split=split)

    # --square makes the latent space as wide as the data.
    if args.square:
        args.latent_dimension = dataset.dimension

    model = VAEModel(
                dirname=output_dir,
                num_epoch=args.num_epochs,
                batch_size=args.batch_size,
                learning_rate=args.learning_rate,
                dataset=dataset,
                layer_sizes=args.layer_sizes,
                encoder_layer_sizes=args.encoder_layer_sizes,
                epsilon=args.epsilon,
                tqdm=True,
                latent_dimension=args.latent_dimension,
                tunable_decoder_var=args.tunable_decoder_var,
                # String concatenation of the two options; presumably at most
                # one of --sample / --refine is given -- TODO confirm.
                load_model = args.sample + args.refine,
                refine_type = args.refine_type,
                seed=args.seed)

    output_path = output_dir + "/" + args.output_latent_file
    if args.encode:
        # --encode takes precedence over --sample.
        print ("encoding only")
        latents = model.encode()
        print (latents.shape)
        np.save(output_path, latents)
    elif args.sample:
        print ("sampling only")
        latents = None
        if args.input_latent_file:
            print ("load previous stage latents")
            latents=np.load(output_dir + "/" + args.input_latent_file, allow_pickle=True)
            latents = jnp.array(latents)
        samples, _z = model.sample_batch(args.sample_size, latents=latents)
        np.save(output_path, samples)
    else:
        # Neither --sample nor --encode: train, then sample from the result.
        model.train()
        samples, _z = model.sample_batch(args.sample_size)
        np.save(output_path, samples)
