import argparse
import d3rlpy


def main():
    """Train DROP on a D4RL dataset using settings taken from the command line.

    Steps:
      1. Parse the CLI arguments.
      2. Load the dataset and environment with ``d3rlpy.datasets.get_d4rl``.
      3. Seed d3rlpy and the environment.
      4. Build the encoder/optimizer factories and construct ``d3rlpy.algos.DROP``.
      5. Run ``fit`` with an environment-rollout scorer.

    For antmaze datasets the environment is also handed to ``fit`` so it can
    be saved alongside the run.
    """
    parser = argparse.ArgumentParser(description="Train DROP on a D4RL dataset.")
    parser.add_argument('--dataset', type=str, default='antmaze-umaze-v2',
                        help="D4RL dataset name passed to d3rlpy.datasets.get_d4rl.")
    parser.add_argument('--num', type=int, default=1000,
                        help="Forwarded to DROP as drop_num.")
    parser.add_argument('--dim', type=int, default=5,
                        help="Forwarded to DROP as drop_dim; also the output "
                             "width of the DROP encoder.")
    parser.add_argument('--size', type=float, default=1,
                        help="Forwarded to DROP as drop_size; values >= 1 are "
                             "converted to int (fractional part dropped).")
    # Data decomposition rules accepted by --type.
    # Stochastic rules:
    #   rane  - random_episodes
    #   rant  - random_transitions
    #   rank  - rank
    #   raner - random_episodes_reward
    #   rang  - random_episodes_final_state
    # And, importantly:
    #   rewa  - corresponds to the "rank" decomposition rule in the DROP paper.
    parser.add_argument('--type', type=str, default='rewa',
                        help="Decomposition rule: rane (random_episodes), "
                             "rant (random_transitions), rank (rank), "
                             "raner (random_episodes_reward), "
                             "rang (random_episodes_final_state), "
                             "rewa (the 'rank' rule of the DROP paper).")
    parser.add_argument('--seed', type=int, default=999,
                        help="Seed for d3rlpy, the environment and DROP (drop_seed).")
    parser.add_argument('--gpu', type=int, default=0,
                        help="Value passed to DROP as use_gpu.")
    parser.add_argument('--ada', type=int, default=0,
                        help="Non-zero enables adaptive mode (fit is_adaptive=True).")
    args = parser.parse_args()
    # NOTE(review): handed to fit() as a mutable list, presumably so fit can
    # record adaptive-iteration state in place — confirm against DROP.fit.
    args.is_adaptive_iters = [None]

    dataset, env = d3rlpy.datasets.get_d4rl(args.dataset)

    # Fix seeds for reproducibility.
    d3rlpy.seed(args.seed)
    env.seed(args.seed)

    # Sizes >= 1 are treated as integers (truncated); smaller values stay float.
    if args.size >= 1:
        args.size = int(args.size)

    # Antmaze runs additionally hand the environment to fit() for saving.
    is_antmaze = "antmaze" in args.dataset

    drop_encoder = d3rlpy.models.encoders.VectorEncoderFactory([1024, 512, args.dim])
    drop_optim = d3rlpy.models.optimizers.AdamFactory(weight_decay=1e-4)
    actor_encoder = d3rlpy.models.encoders.VectorEncoderFactory([512, 512])
    actor_optim = d3rlpy.models.optimizers.AdamFactory(weight_decay=1e-4)
    critic_encoder = d3rlpy.models.encoders.VectorEncoderFactory([512, 512])
    critic_optim = d3rlpy.models.optimizers.AdamFactory(weight_decay=1e-4)

    drop = d3rlpy.algos.DROP(drop_num=args.num,
                             drop_dim=args.dim,
                             drop_size=args.size,
                             drop_seed=args.seed,
                             drop_type=args.type,
                             drop_learning_rate=1e-3,
                             actor_learning_rate=1e-3,
                             critic_learning_rate=1e-3,
                             drop_encoder_factory=drop_encoder,
                             drop_optim_factory=drop_optim,
                             actor_encoder_factory=actor_encoder,
                             actor_optim_factory=actor_optim,
                             critic_encoder_factory=critic_encoder,
                             critic_optim_factory=critic_optim,
                             embedding_learning_rate=1e-3,
                             energy_learning_rate=3e-4,
                             energy_update_steps=1,
                             batch_size=1024,
                             use_gpu=args.gpu,
                             gamma=1.0,
                             tau=0.01,
                             policy_type="deterministic")

    drop.fit(dataset,
             eval_episodes=None,
             n_steps=500000,
             n_steps_per_epoch=500,
             save_interval=10,
             is_adaptive=bool(args.ada),
             is_adaptive_iters=args.is_adaptive_iters,
             scorers={
                 'environment': d3rlpy.metrics.evaluate_on_environment(env),
             },
             experiment_name=f"DROP_{args.dataset}_{args.type}_{args.ada}_{args.dim}_{args.size}_{args.num}_{args.seed}",
             action_space=env.action_space,
             observation_space=env.observation_space,
             save_env=is_antmaze,
             env_to_be_save=env if is_antmaze else None)


if __name__ == "__main__":
    # Launch training only when run as a script, not when imported.
    main()
