import argparse
from stable_baselines3.sac import SACBrm, SACBrmNB, SACBrmSplit
import torch
import time
from distutils.util import strtobool


# Per-environment target-entropy values. When ``mbpo_entrop`` is enabled, ``learn`` passes the
# entry for the chosen environment as ``target_entropy``; unlisted environments use "auto".
# (Presumably taken from the MBPO setup, per the name — confirm against the MBPO reference.)
mbpo_target_entropy_dict = {
    'Hopper-v4': -1.0,
    'HalfCheetah-v4': -3.0,
    'Walker2d-v4': -3.0,
    'Ant-v4': -4.0,
    'Humanoid-v4': -2.0,
}


def learn(env_name: str, seed: int, logging_path: str, max_steps: int, w_mult: float, w_brm: float,
          use_target_brm: bool, mbpo_entrop: bool, tau: float, target_update_interval: int, mode: int,
          nb_hidden_q: int) -> None:
    """Construct one SAC-BRM variant for ``env_name`` and train it.

    Args:
        env_name: Environment id passed as the second positional constructor argument.
        seed: Seed forwarded to the algorithm constructor.
        logging_path: Forwarded as ``tensorboard_log``.
        max_steps: Passed as the first argument of the model's ``.learn()``.
        w_mult: Forwarded as ``weight_lr_mult``.
        w_brm: Forwarded as ``weight_brm``.
        use_target_brm: Forwarded as ``use_target_w_brm`` for modes 0 and 1. The split
            variant (mode 2) is constructed without it, so it has no effect there.
        mbpo_entrop: If true, use the entry of ``mbpo_target_entropy_dict`` for ``env_name``
            as ``target_entropy``; environments missing from the table use ``"auto"``.
        tau: Forwarded as ``tau``.
        target_update_interval: Forwarded as ``target_update_interval``.
        mode: 0 -> ``SACBrm``, 1 -> ``SACBrmNB``, 2 -> ``SACBrmSplit``.
        nb_hidden_q: Number of 256-wide entries in the Q-network ``net_arch``; the policy
            ``net_arch`` is always two 256-wide entries.

    Raises:
        ValueError: If ``mode`` is not 0, 1 or 2.
    """
    # Resolve the algorithm class first so a bad mode fails before any global state changes,
    # instead of silently falling through to the split variant.
    if mode == 0:
        algo_cls, variant_kwargs = SACBrm, {'use_target_w_brm': use_target_brm}
    elif mode == 1:
        algo_cls, variant_kwargs = SACBrmNB, {'use_target_w_brm': use_target_brm}
    elif mode == 2:
        # The split variant is built without ``use_target_w_brm``.
        algo_cls, variant_kwargs = SACBrmSplit, {}
    else:
        raise ValueError(f'mode must be 0 (brm), 1 (brm no bias) or 2 (brm split), got {mode!r}')

    torch.set_num_threads(1)
    # Environments without a tabulated value keep the "auto" target entropy.
    te = mbpo_target_entropy_dict.get(env_name, "auto") if mbpo_entrop else "auto"
    pol_kwargs = {'net_arch': dict(qf=[256] * nb_hidden_q, pi=[256] * 2)}
    model = algo_cls('MlpPolicy', env_name, learning_starts=10000, tensorboard_log=logging_path,
                     device='cpu', seed=seed, weight_lr_mult=w_mult, weight_brm=w_brm,
                     target_entropy=te, tau=tau, target_update_interval=target_update_interval,
                     policy_kwargs=pol_kwargs, **variant_kwargs)
    model.learn(max_steps)


def _str2bool(value):
    """Parse a command-line truth value the same way ``distutils.util.strtobool`` did.

    Accepts (case-insensitively) y/yes/t/true/on/1 as True and n/no/f/false/off/0 as False.
    ``distutils`` was removed in Python 3.12, so this keeps the CLI working without it.

    Raises:
        ValueError: For any other string; argparse reports it as a usage error.
    """
    lowered = value.lower()
    if lowered in ('y', 'yes', 't', 'true', 'on', '1'):
        return True
    if lowered in ('n', 'no', 'f', 'false', 'off', '0'):
        return False
    raise ValueError('invalid truth value %r' % (lowered,))


def build_parser():
    """Return the CLI parser; every option's dest matches a parameter of ``learn``."""
    argp = argparse.ArgumentParser()
    argp.add_argument('--env-name', type=str, default='HalfCheetah-v4')
    argp.add_argument('--seed', type=int, default=0)
    argp.add_argument('--logging-path', type=str, default='',
                      help='TensorBoard log dir; empty means ./runs/<env>_<seed>_<unix time>')
    argp.add_argument('--max-steps', type=int, default=1_000_000)
    argp.add_argument('--w-mult', type=float, default=1., help='passed as weight_lr_mult')
    # Bare flag means True; an explicit value (e.g. "false") is parsed by _str2bool.
    argp.add_argument('--use-target-brm', type=_str2bool, default=False, nargs="?", const=True,
                      help='passed as use_target_w_brm (not used by --mode 2)')
    argp.add_argument('--mode', type=int, default=1, choices=(0, 1, 2),
                      help='0: brm, 1: brm no bias, 2: brm split')
    argp.add_argument('--nb-hidden-q', type=int, default=2,
                      help='number of 256-wide layers in the Q-network net_arch')
    argp.add_argument('--mbpo-entrop', type=_str2bool, default=True, nargs="?", const=True,
                      help='use the per-environment target entropy table instead of "auto"')
    argp.add_argument('--w-brm', type=float, default=1., help='passed as weight_brm')
    argp.add_argument('--target-update-interval', type=int, default=10)
    argp.add_argument('--tau', type=float, default=1.)
    return argp


if __name__ == '__main__':
    args = build_parser().parse_args()
    paras = vars(args)
    if args.logging_path == '':
        paras['logging_path'] = f'./runs/{args.env_name}_{args.seed}_{int(time.time())}'
        print('log_dir', paras['logging_path'])
    learn(**paras)
