
import bench, logger

def train(env_id, num_timesteps, seed, policy, lr_ac, lr_cr, ircr, reg, scale_reward,
          policy_freq, add_t, reward_n, drop_type, drop_r, rand_r, buffer_size):
    """Create the train/test environments and run ``sac.learn`` inside a TF session.

    The TensorFlow session is opened with a ``with`` block, so it is the
    default session for the whole run and is exited again when training
    finishes or raises.

    Args:
        env_id: Gym environment id; passed to ``gym.make`` twice to build
            separate training and evaluation environments.
        num_timesteps: Forwarded to ``sac.learn`` as ``total_timesteps``.
        seed: Passed to ``set_global_seeds``.
        policy: Policy architecture name. Only ``'mlp'`` is supported.
        lr_ac, lr_cr, ircr, reg, scale_reward, policy_freq, add_t, reward_n,
        drop_type, drop_r, rand_r, buffer_size: Forwarded unchanged to
            ``sac.learn`` under the same keyword names.

    Raises:
        NotImplementedError: If ``policy`` is anything other than ``'mlp'``.
    """
    # Heavy / project-local dependencies are imported lazily so that importing
    # this module (e.g. for ``--help``) does not pull them in.
    from common import set_global_seeds
    import sac
    from policies import MlpPolicy
    import gym
    import tensorflow as tf

    config = tf.ConfigProto(allow_soft_placement=True,
                            intra_op_parallelism_threads=1,
                            inter_op_parallelism_threads=1)
    config.gpu_options.allow_growth = True

    # Scope the session with ``with`` instead of a bare ``__enter__()`` so the
    # matching ``__exit__`` always runs, even if training raises.
    with tf.Session(config=config):
        # NOTE(review): both monitors are given the same logger directory;
        # confirm that bench.Monitor keeps their outputs apart.
        train_env = gym.make(env_id)
        train_env = bench.Monitor(train_env, logger.get_dir())
        test_env = gym.make(env_id)
        test_env = bench.Monitor(test_env, logger.get_dir())

        set_global_seeds(seed)

        if policy == 'mlp':
            policy = MlpPolicy
        else:
            raise NotImplementedError

        sac.learn(policy=policy,
                  train_env=train_env,
                  test_env=test_env,
                  lr_ac=lr_ac,
                  lr_cr=lr_cr,
                  ircr=ircr,
                  reg=reg,
                  scale_reward=scale_reward,
                  policy_freq=policy_freq,
                  total_timesteps=num_timesteps,
                  reward_n=reward_n,
                  drop_type=drop_type,
                  drop_r=drop_r,
                  rand_r=rand_r,
                  add_t=add_t,
                  buffer_size=buffer_size)


def main():
    """Parse command-line options, configure the log directory and launch training.

    The log directory is
    ``./data/<task>/HC-Pairwise-1/<alg>/<env>/<seed>``, where ``<task>`` is
    built from the drop/delay/reward-n options and ``<alg>`` from the
    ``--ircr``, ``--add-t`` and ``--reg`` options.
    """
    import argparse

    # (flag, add_argument keyword arguments), in the order they are registered.
    option_table = [
        ('--env', dict(help='Environment ID', default='Hopper-v2')),
        ('--seed', dict(help='RNG seed', type=int, default=0)),
        ('--policy', dict(help='Policy architecture', choices=['mlp'], default='mlp')),
        ('--num-timesteps', dict(type=int, default=int(1E6))),
        ('--ircr', dict(type=int, default=0)),
        ('--add-t', dict(type=int, default=1)),
        ('--reg', dict(type=float, default=0.5)),
        ('--lr-ac', dict(type=float, default=3E-4)),
        ('--lr-cr', dict(type=float, default=3E-4)),
        ('--policy-freq', dict(type=int, default=1)),
        ('--scale-reward', dict(type=float, default=1.)),
        # env parameters
        ('--reward-n', dict(type=int, default=20)),
        ('--drop-type', dict(choices=['Sum'], default='Sum')),
        ('--drop-pr', dict(type=float, default=0.0)),
        ('--delay-pr', dict(type=float, default=0.0)),
    ]
    cli = argparse.ArgumentParser()
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    opts = cli.parse_args()

    use_ircr = bool(opts.ircr)
    use_t = bool(opts.add_t)

    # Task part of the log dir: the drop probability is only included when it
    # exceeds 1e-2.
    drop_suffix = str(opts.drop_pr) if opts.drop_pr > 1E-2 else ''
    task = '{}{}_delay{}_n{}'.format(
        opts.drop_type, drop_suffix, opts.delay_pr, int(opts.reward_n))

    # Algorithm part of the log dir.
    alg = '{}{}_reg{}'.format(
        'ircr_' if use_ircr else '', 't' if use_t else 'n', opts.reg)

    log_dir = './data/{}/HC-Pairwise-1/{}/{}/{}'.format(task, alg, opts.env, opts.seed)
    logger.configure(dir=log_dir)

    train(
        env_id=opts.env,
        num_timesteps=opts.num_timesteps,
        seed=opts.seed,
        policy=opts.policy,
        policy_freq=opts.policy_freq,
        add_t=use_t,
        ircr=use_ircr,
        lr_ac=opts.lr_ac,
        lr_cr=opts.lr_cr,
        reg=opts.reg,
        scale_reward=opts.scale_reward,
        buffer_size=int(2e6),
        reward_n=opts.reward_n,
        drop_type=opts.drop_type,
        drop_r=opts.drop_pr,    # --drop-pr is passed on as drop_r
        rand_r=opts.delay_pr,   # --delay-pr is passed on as rand_r
    )


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    main()
