import numpy as np
import torch
import torch.nn as nn
import gym
import hirid_env
import sepsis_env

import argparse
import os
import d4rl
import numpy as np

from delphicORL.utils import utils, data, logging, scorers


def build_arg_parser():
    """Build the command-line parser for this OPE training script.

    Returns:
        argparse.ArgumentParser: parser exposing ``--command_name``, ``--env``,
        ``--seed``, ``--eval_freq``, ``--max_timesteps`` and
        ``--extra_log_rep``.
    """
    parser = argparse.ArgumentParser()

    # Experiment
    parser.add_argument("--command_name", default="ope/dm")   # OPE name
    parser.add_argument("--env", default="hopper-medium-v0")  # OpenAI gym environment name
    parser.add_argument("--seed", default=0, type=int)        # Sets Gym, PyTorch and Numpy seeds
    # argparse only runs ``type`` on string values, so a float default such as
    # ``5e3`` would be handed back unchanged as a float. The defaults are
    # therefore spelled as real ints to match the declared ``type=int``.
    parser.add_argument("--eval_freq", default=5000, type=int)           # How often (time steps) we evaluate
    parser.add_argument("--max_timesteps", default=1_000_000, type=int)  # Max time steps to run environment

    # Logging
    parser.add_argument("--extra_log_rep", type=str, default="")

    return parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()

    print("---------------------------------------")
    print(f"OPE Type: {args.command_name}, Env: {args.env}, Seed: {args.seed}")
    print("---------------------------------------")

    # NOTE: ``logging`` here is the project module imported from
    # delphicORL.utils, not the standard-library ``logging`` package.
    custom_logger, log_dir = logging.setup_logging(args)
    env = gym.make(args.env)
    utils.set_seeds(args.seed, env)

    # Load the demonstrations for this environment, then split them into the
    # portion used for fitting and a held-out test portion.
    expert_trajs = data.get_imitation_dataset(args.env)  # , n_max_demos=args.n_trajs
    expert_trajs, test_trajs = data.split_datasets(expert_trajs)

    # Continuous observation space, discrete action space.
    # 'ope/dr' appears in both membership tests below, so it runs both the
    # reward-model fit and the propensity-function training.
    if args.command_name in ['ope/dm', 'ope/dr']:
        reward_fn = scorers.RewardFn(
            observation_space=env.observation_space,
            action_space=env.action_space,
            demonstrations=expert_trajs,
            test_demonstrations=test_trajs,
            custom_logger=custom_logger,
        )
        # Fit for 20 epochs; ``log_dir`` is passed as the save location.
        reward_fn.fit(n_epochs=20, save_to=log_dir)

    if args.command_name in ['ope/ipw', 'ope/dr']:
        # NOTE(review): indexing ``observation_space`` and calling ``len()`` on
        # ``action_space`` only work for space types that implement
        # ``__getitem__`` / ``__len__``; plain gym ``Box`` / ``Discrete``
        # spaces presumably need ``.shape[0]`` / ``.n`` instead -- TODO
        # confirm against the spaces exposed by the target environments.
        propensity_fn = scorers.train_propensity_function(
            expert_trajs,
            obs_dim=env.observation_space[0],
            act_dim=len(env.action_space),
            save_to=log_dir,
        )

 