from LogRegIRL import DeepLogRegIRL
import numpy as np
import argparse

print("### Deep Logistic Regression-Based IRL by Chainer ###")

parser = argparse.ArgumentParser(
            prog='Deep Logistic Regression-Based IRL (LogReg-IRL)',
            usage='Estimating reward and state-value by Deep LogReg-IRL.',
            description='#### argument description ####',
            epilog='##############################',
            add_help=True,
            )

# ArgumentParser
# Numeric options declare an explicit `type`. Without it argparse hands
# command-line values through as raw strings (e.g. "--gamma 0.9" -> "0.9"),
# while the defaults below are int/float. The type of each value would then
# depend on whether the user passed the flag at all.
parser.add_argument('--exp', default="exp.csv", help='Data name of Expert trajectories.')
parser.add_argument('--base', default="base.csv", help='Data name of Baseline trajectories.')
parser.add_argument('--n_hidden_f', type=int, default=1024, help='Number of units in hidden layers for f-network.')
parser.add_argument('--n_hidden_qv', type=int, default=1024, help='Number of units in hidden layers for q-/V-network.')
parser.add_argument('--n_epoch_f', type=int, default=1000, help='Number of training epochs for f-network.')
parser.add_argument('--n_epoch_qv', type=int, default=1000, help='Number of training epochs for q-/V-network.')
parser.add_argument('--n_step_f', type=int, default=100, help='Number of steps in one training epoch for f-network.')
parser.add_argument('--n_step_qv', type=int, default=100, help='Number of steps in one training epoch for q-/V-network.')
parser.add_argument('--batch_size', type=int, default=1024, help='Size of minibatch.')
# Negative values such as "--gpu -1" parse fine: no option string here looks
# like a negative number, so argparse treats "-1" as this option's value.
parser.add_argument('--gpu', type=int, default=0, help='Device ID of GPU. For CPU, use --gpu -1')
parser.add_argument('--train_model', default="./model_LogReg/", help='Path for saving models of training.')
parser.add_argument('--save_model', default="./saved_model", help='Path for saving best models.')
parser.add_argument('--gamma', type=float, default=0.99, help='Discount rate in MDP (0 < gamma <= 1).')

print("Reading arguments...")
# Parse the command line against the options registered on `parser`.
args = parser.parse_args()
print("... done.")

print("Launching Deep LogReg-IRL...")

# Translate CLI option names into DeepLogRegIRL keyword arguments.
# The constructor names the f-network / q-V-network epoch and step counts
# n_epoch1 / n_epoch2 and n_step1 / n_step2 respectively.
irl_options = dict(
    n_hidden_f=args.n_hidden_f,
    n_hidden_qv=args.n_hidden_qv,
    n_epoch1=args.n_epoch_f,
    n_epoch2=args.n_epoch_qv,
    n_step1=args.n_step_f,
    n_step2=args.n_step_qv,
    gamma=args.gamma,
    batch_size=args.batch_size,
    gpu_device=args.gpu,
    model_path=args.train_model,
    save_path=args.save_model,
)
# Expert and baseline trajectory files are passed positionally; the
# returned object is not used afterwards.
DeepLogRegIRL(args.exp, args.base, **irl_options)

print("### Finish! ###")
