import numpy as np
import torch
import gym
import hirid_env
import sepsis_env
import argparse
import os
import d4rl

from delphicORL.utils import utils, logging, data
from delphicORL.algos.confounding.worldmodel_learner import WorldModelLearner


if __name__ == "__main__":
    
    parser = argparse.ArgumentParser()
    # Experiment
    parser.add_argument("--command_name", default="worldmodels")   
    parser.add_argument("--env", default="hopper-medium-v0")        # OpenAI gym environment name
    parser.add_argument("--seed", default=0, type=int)              # Sets Gym, PyTorch and Numpy seeds
    parser.add_argument("--eval_freq", default=5e3, type=int)       # How often (time steps) we evaluate
    parser.add_argument("--max_timesteps", default=1e6, type=int)   # Max time steps to run environment
    parser.add_argument("--lstm", type=str, default='gru')
    parser.add_argument("--normalize", default=True)
    parser.add_argument("--extra_log_rep", type=str, default = "")

    args = parser.parse_args()

    print("---------------------------------------")
    print(f"Model: {args.command_name}, Env: {args.env}, Seed: {args.seed}")
    print("---------------------------------------")

    custom_logger, log_dir = logging.setup_logging(args)
    env = gym.make(args.env)
    utils.set_seeds(args.seed, env)

    expert_trajs = data.get_imitation_dataset(args.env, n_max_demos=args.n_trajs)
    expert_trajs, test_trajs = data.split_datasets(expert_trajs)

    wmlearner = WorldModelLearner(
            observation_space=env.observation_space,
            action_space=env.action_space,
            demonstrations=expert_trajs,
            test_demonstrations=test_trajs,
            custom_logger=custom_logger,
            batch_size=32,
            optimizer_cls=torch.optim.Adam,
            lstm = True, #args.lstm is not None,
            optimizer_kwargs=dict(lr=4e-4),
            max_len=20,
            wm_klweight=1e-2,
            wm_target_dim=1,
        )

    wmlearner.train(
        n_epochs=50, 
        log_interval=5,
    )
    wmlearner.save_clvae(vae_path=os.path.join(log_dir, "final.th"))