import argparse
import random

import numpy as np
import torch
from all2.environments import AtariEnvironment
from all2.experiments import run_experiment

from preset import jointdqn


def set_seed(seed):
    """Seed every global random number generator used by the experiment.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch's CPU and
    CUDA generators, then puts cuDNN into deterministic mode so repeated runs
    with the same seed are reproducible.

    Args:
        seed: Integer seed applied to every generator.
    """
    # Previously missing: any code drawing from the stdlib RNG was unseeded.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # These calls are silently ignored when CUDA is unavailable, so they are
    # safe on CPU-only machines.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for run-to-run determinism.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def parse_args(argv=None):
    """Parse command-line options for an Atari benchmark run.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with ``env``, ``device``, ``frames`` and
        ``seed`` attributes.
    """
    parser = argparse.ArgumentParser(description="Run an Atari benchmark.")
    # argparse only honors a positional's default when nargs="?"; without it
    # the argument is mandatory and the default is dead.
    parser.add_argument(
        "env",
        nargs="?",
        default="Pong",
        help="Name of the Atari game (e.g. Pong)",
    )
    parser.add_argument(
        "--device",
        default="cuda:0",
        help="The name of the device to run the agent on (e.g. cpu, cuda, cuda:0)",
    )
    # argparse does not apply `type` to non-string defaults, so the default
    # must already be an int (5e7 would be passed through as a float).
    parser.add_argument(
        "--frames", type=int, default=50_000_000, help="The number of training frames"
    )
    parser.add_argument("--seed", type=int, default=0)
    return parser.parse_args(argv)


def run():
    """Parse CLI options, seed all RNGs, build the environment and run the experiment."""
    args = parse_args()

    print(f"Running with seed {args.seed}")
    # Seed the global RNGs before constructing the environment or agent, so any
    # random draws made during construction are covered by the seed as well.
    set_seed(args.seed)

    # create atari environment
    env = AtariEnvironment(args.env, device=args.device)
    env.seed(args.seed)

    # run the experiment with the jointdqn preset placed on the chosen device
    run_experiment(jointdqn.device(args.device), env, args.frames)


if __name__ == "__main__":
    # Launch the benchmark only when executed as a script, not when imported.
    run()
