import argparse
from stable_baselines3.dqn import DQN
import torch
import time



def learn(env_name, seed, logging_path, max_steps, w_brm, target_update_interval):
    """Train a DQN agent with an MLP policy on ``env_name`` and return it.

    The ``w_brm`` keyword is not part of upstream stable-baselines3's ``DQN``;
    this presumably relies on a locally patched/forked DQN (confirm).

    Args:
        env_name: Environment id handed directly to ``DQN`` (e.g. 'LunarLander-v2').
        seed: Random seed forwarded to ``DQN``.
        logging_path: Directory passed to ``DQN`` as ``tensorboard_log``.
        max_steps: Total timesteps passed to ``DQN.learn``.
        w_brm: Forwarded to ``DQN`` as ``w_brm`` -- presumably the weight of a
            Bellman-residual-minimisation loss term; verify against the fork.
        target_update_interval: Forwarded to ``DQN`` as ``target_update_interval``.

    Returns:
        The ``DQN`` model after ``learn(max_steps)`` has completed, so callers
        can save or evaluate it.
    """
    # Restrict PyTorch to one intra-op thread (presumably so several seeds can
    # run side by side on one machine without oversubscribing the CPU).
    torch.set_num_threads(1)

    # Three fully-connected hidden layers of 256 units each.
    pol_kwargs = {'net_arch': [256] * 3}
    model = DQN('MlpPolicy', env_name, tensorboard_log=logging_path, device='cpu', seed=seed,
                learning_starts=10000, learning_rate=3e-4, train_freq=1,
                target_update_interval=target_update_interval,
                policy_kwargs=pol_kwargs, w_brm=w_brm)

    # Alternative configuration: CNN policy without image normalisation,
    # presumably for the MinAtar environments. Uncomment to replace the model
    # above. NOTE: it hard-codes train_freq=2 and target_update_interval=1 and
    # ignores the ``target_update_interval`` argument.
    # pol_kwargs = {'normalize_images': False}
    # model = DQN('CnnPolicy', env_name, tensorboard_log=logging_path, device='cpu', seed=seed,
    #             learning_starts=10000, learning_rate=3e-4, train_freq=2, target_update_interval=1,
    #             policy_kwargs=pol_kwargs, w_brm=w_brm)

    model.learn(max_steps)
    return model


if __name__ == '__main__':
    # Command-line entry point: parse the run's hyper-parameters and train once.
    parser = argparse.ArgumentParser()
    # Other environment ids that have been used with this script:
    #   MinAtar/Asterix-v1, MinAtar/SpaceInvaders-v1,
    #   MinAtar/Freeway-v1, MinAtar/Breakout-v1
    parser.add_argument('--env-name', type=str, default='LunarLander-v2')
    parser.add_argument('--seed', type=int, default=1)
    # Empty string means "derive a timestamped directory under ./runs".
    parser.add_argument('--logging-path', type=str, default='')
    parser.add_argument('--max-steps', type=int, default=100_000)
    parser.add_argument('--w-brm', type=float, default=0.2)
    parser.add_argument('--target-update-interval', type=int, default=10)

    # vars() exposes the parsed namespace as a dict whose keys match learn()'s
    # parameter names (argparse turns dashes into underscores).
    cli_args = vars(parser.parse_args())
    if not cli_args['logging_path']:
        run_dir = f"./runs/{cli_args['env_name']}_{cli_args['seed']}_{int(time.time())}"
        cli_args['logging_path'] = run_dir
        print('log_dir', run_dir)
    learn(**cli_args)
