import argparse
import os
from datetime import datetime

# TensorFlow reads TF_CPP_MIN_LOG_LEVEL while its native runtime is loading, so
# the variable must be set before `tensorflow` (or any project module that may
# import it) is imported; assigning it afterwards does not silence import-time
# C++ log output.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import tensorflow as tf  # noqa: E402  (must follow the env setup above)

from algo.Algo import Algo  # noqa: E402
from config.config import define_config  # noqa: E402


# Spellings accepted for boolean command-line values (compared case-insensitively).
_TRUE_STRINGS = frozenset({'true', 't', 'yes', 'y', '1', 'on'})
_FALSE_STRINGS = frozenset({'false', 'f', 'no', 'n', '0', 'off'})


def _str2bool(value):
    """Convert a command-line string such as ``"True"``/``"false"``/``"1"`` to a bool.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is ``True``
    because any non-empty string is truthy, so such flags could never be
    switched off from the command line.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised spelling;
            argparse turns this into a usage error.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in _TRUE_STRINGS:
        return True
    if normalized in _FALSE_STRINGS:
        return False
    raise argparse.ArgumentTypeError(
        f"expected a boolean value (e.g. True/False), got {value!r}")


def parse_args(argv=None):
    """Parse the command-line options for training / evaluation.

    Boolean options take an explicit value, e.g. ``--evaluate False``.

    Args:
        argv: Sequence of argument strings to parse. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace with the attributes ``env``, ``dataset_type``,
        ``latent_algo``, ``risky_env``, ``load_pretrained_latent``,
        ``load_latent_real_dataset``, ``load_imit_actor``, ``evaluate`` and
        ``num_evaluations``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--env', action="store", type=str, default='walker_walk')
    parser.add_argument('--dataset_type', action="store", type=str)
    parser.add_argument('--latent_algo', action="store", type=str)
    parser.add_argument('--risky_env', action="store", type=_str2bool, default=True)
    parser.add_argument('--load_pretrained_latent', action='store', type=_str2bool, default=True)
    parser.add_argument('--load_latent_real_dataset', action='store', type=_str2bool, default=True)
    parser.add_argument('--load_imit_actor', action='store', type=_str2bool, default=False)
    parser.add_argument('--evaluate', action='store', type=_str2bool, default=True)
    parser.add_argument('--num_evaluations', action='store', type=int, default=100)
    return parser.parse_args(argv)


if __name__ == '__main__':
    # Script entry point. Runs exactly one of two mutually exclusive modes:
    #   * training (--evaluate False): obtain a latent variable model, obtain
    #     the latent buffer, train the actor-critic and save the agent;
    #   * evaluation (default): load the saved latent model and agent, then
    #     run `args.num_evaluations` evaluation episodes.

    # Let TensorFlow grow GPU memory on demand instead of reserving it all up
    # front (this must happen before any GPU is initialised).
    for gpu in tf.config.experimental.list_physical_devices('GPU'):
        tf.config.experimental.set_memory_growth(gpu, True)

    args = parse_args()
    config = define_config(args)

    if not args.evaluate:
        # initialization of the algorithm
        algo = Algo(config)

        # Train or load trained latent variable model
        if args.load_pretrained_latent:
            algo.load_latent_model(config.trained_latent_dir / 'final_latent_model')
            print("We loaded a pretrained latent variable model.")
        else:
            algo.latent_model_training(config.num_train_step_latent_model)

        # Load latent dataset or build it; both paths use the same file.
        latent_buffer_path = config.trained_latent_dir / 'buffer.h5py'
        if args.load_latent_real_dataset:
            algo.latent_buffer.load(latent_buffer_path)
            print("We loaded the latent buffer.")
        else:
            algo.process_data_to_latent(print_process=True)
            algo.latent_buffer.save(latent_buffer_path)

        # Actor critic training
        algo.train()
        # NOTE(review): this reaches into the private `_actor_critic`
        # attribute; a public save method on Algo (mirroring `load_agent`)
        # would be cleaner.
        algo._actor_critic.save_actor_critic(config.logdir / 'final_agent')
    else:
        print("We are going to evaluate a trained model.")
        # initialization of the algorithm
        algo = Algo(config)

        algo.load_latent_model(config.trained_latent_dir / 'final_latent_model')
        print("We loaded the trained latent variable model.")

        algo.load_agent(config.logdir / 'final_agent')
        algo.evaluate(episodes=args.num_evaluations)











