import argparse
from stable_baselines3.sac import SAC
import torch
import time

# Per-environment target-entropy values, keyed by environment id. `learn`
# passes the matching value to SAC as `target_entropy`.
# NOTE(review): the name suggests these come from the MBPO setup — confirm.
mbpo_target_entropy_dict = {
    'Hopper-v4': -1.0,
    'HalfCheetah-v4': -3.0,
    'Walker2d-v4': -3.0,
    'Ant-v4': -4.0,
    'Humanoid-v4': -2.0,
}


def learn(env_name, seed, logging_path, max_steps, nb_hidden_q):
    """Build a SAC agent for ``env_name`` and train it.

    Args:
        env_name: Environment id handed to ``SAC``; also the key used to look
            up a target entropy in ``mbpo_target_entropy_dict``.
        seed: Seed forwarded to ``SAC``.
        logging_path: Forwarded to ``SAC`` as ``tensorboard_log``.
        max_steps: First argument to the agent's ``learn`` call (the total
            number of timesteps in Stable-Baselines3).
        nb_hidden_q: Number of 256-unit hidden layers in the Q-function
            network; the policy network always gets two such layers.
    """
    # Restrict torch to a single intra-op thread for this process.
    torch.set_num_threads(1)

    # Environments with no tabulated value fall back to SAC's "auto" setting.
    target_entropy = mbpo_target_entropy_dict.get(env_name, "auto")

    net_arch = {'qf': [256] * nb_hidden_q, 'pi': [256, 256]}
    agent = SAC(
        'MlpPolicy',
        env_name,
        learning_starts=10000,
        tensorboard_log=logging_path,
        device='cpu',
        seed=seed,
        target_entropy=target_entropy,
        policy_kwargs={'net_arch': net_arch},
    )
    agent.learn(max_steps)


if __name__ == '__main__':
    # Script entry point: parse the command line and forward every option to
    # `learn` as a keyword argument (argparse turns `--env-name` into
    # `env_name`, etc., which matches `learn`'s parameter names).
    parser = argparse.ArgumentParser()
    cli_options = (
        ('--env-name', str, 'Humanoid-v4'),
        ('--seed', int, 0),
        ('--logging-path', str, ''),
        ('--max-steps', int, 500_000),
        ('--nb-hidden-q', int, 2),
    )
    for flag, value_type, fallback in cli_options:
        parser.add_argument(flag, type=value_type, default=fallback)

    args = parser.parse_args()
    # vars(args) is the namespace's own attribute dict, so writes to `paras`
    # are visible through `args` as well.
    paras = vars(args)
    if not args.logging_path:
        # No path given: derive one from the env name, the seed and the
        # current Unix time (whole seconds).
        paras['logging_path'] = f'./runs/{args.env_name}_{args.seed}_{int(time.time())}'
        print('log_dir', paras['logging_path'])
    learn(**paras)
