import argparse
import os

import torch
import gym

from delphicORL.utils import utils, logging, data, scorers
from delphicORL.algos.imitation.bc import BC

def _str2bool(value):
    """Convert a command-line token into a bool.

    argparse hands every command-line value over as a string, and any
    non-empty string is truthy, so a flag declared without ``type`` can never
    be switched off with ``--flag False``.  This converter makes the common
    spellings work.

    Args:
        value: A bool (returned unchanged) or a string such as ``"true"``,
            ``"False"``, ``"1"``, ``"no"``.  The empty string maps to
            ``False`` because it was already falsy before this converter
            was introduced.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean
            spelling.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ("1", "true", "t", "yes", "y", "on"):
        return True
    if token in ("", "0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args(argv=None):
    """Parse the command-line options of the behaviour-cloning script.

    Args:
        argv: Optional list of argument strings.  ``None`` makes argparse
            read ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    # Experiment
    parser.add_argument("--command_name", default="imitation/BC")   # Policy name
    parser.add_argument("--env", default="hopper-medium-v0")        # OpenAI gym environment name
    # NOTE(review): nothing in this script seeds Gym/PyTorch/NumPy itself;
    # presumably logging.setup_logging(args) consumes the seed -- confirm.
    parser.add_argument("--seed", default=0, type=int)
    # The default must already be an int: argparse only applies `type` to
    # string defaults, so the previous `default=5e3` stayed a float.
    # NOTE(review): this value is not read anywhere in this script.
    parser.add_argument("--eval_freq", default=5000, type=int)      # How often (time steps) we evaluate
    parser.add_argument("--lstm", type=str)                         # Forwarded to BC as `lstm_model`; omit for the non-LSTM setup
    parser.add_argument("--fully_observable", action='store_true')
    parser.add_argument("--ope", action='store_true')
    # NOTE(review): this value is not read anywhere in this script.
    parser.add_argument("--load_model", default="")                 # Model load file name, "" doesn't load, "default" uses file_name

    # Logging
    parser.add_argument("--extra_log_rep", type=str, default="")
    # Parsed explicitly so that `--normalize False` really yields False.
    parser.add_argument("--normalize", default=True, type=_str2bool)
    return parser.parse_args(argv)


def main(argv=None):
    """Train a BC policy on the imitation dataset of ``--env`` and save it.

    Args:
        argv: Optional list of argument strings, forwarded to ``parse_args``.
    """
    args = parse_args(argv)
    use_lstm = args.lstm is not None

    bc_kwargs = dict(
        batch_size=2 if use_lstm else 32,
        l2_weight=0,
        optimizer_cls=torch.optim.Adam,
        lstm=use_lstm,
        lstm_model=args.lstm,
        optimizer_kwargs={'lr': 4e-4},
    )

    bc_train_kwargs = dict(
        n_epochs=100,
        log_interval=5,
    )

    if 'hirid' in args.env:
        # Environments whose id contains "hirid" train for more epochs and
        # request zero rollout episodes for logging.
        bc_train_kwargs['n_epochs'] = 200
        bc_train_kwargs['log_rollouts_n_episodes'] = 0

    if use_lstm:
        args.command_name += f'_{args.lstm}'
    if args.fully_observable:
        args.command_name += '_FO'
        # NOTE(review): replaces *every* "po" substring in the env id, not
        # just a dedicated partially-observable tag -- confirm the env naming
        # scheme makes this safe.
        args.env = args.env.replace('po', 'fo')
    print("---------------------------------------")
    print(f"Policy: {args.command_name}, Env: {args.env}, Seed: {args.seed}")
    print("---------------------------------------")

    custom_logger, log_dir = logging.setup_logging(args)
    env = gym.make(args.env)

    expert_trajs = data.get_imitation_dataset(args.env)
    expert_trajs, test_trajs = data.split_datasets(expert_trajs)

    bc_trainer = BC(
        observation_space=env.observation_space,
        action_space=env.action_space,
        demonstrations=expert_trajs,
        test_demonstrations=test_trajs,
        custom_logger=custom_logger,
        ope=args.ope,
        **bc_kwargs,
    )
    if args.ope:
        # With --ope the trainer is additionally given the environment.
        bc_trainer.env = env

    bc_trainer.train(log_rollouts_venv=env, **bc_train_kwargs)
    bc_trainer.save_policy(policy_path=os.path.join(log_dir, "final.th"))


if __name__ == "__main__":
    main()