from agent import *
from envs import *
import argparse
from warnings import filterwarnings
filterwarnings(
    action="ignore",
    category=DeprecationWarning,
    message="`np.bool8` is a deprecated alias for `np.bool_`",
)

if __name__ == "__main__":
    args = read_args()
    setup_logging(args.log_level) 
    torch.set_num_threads(torch.get_num_threads())
    model = build_agent(args)
    model.train()

# python main.py --model="mop" --exp_name="train_to_delete" --env="ant" --ac_model=mlp --layers=256,256,256
# python main.py --model="mop" --exp_name="debug" --env="ant" --ac_model=vae --vq            --vq_num_emb=32  --vq_cd=2.0 --vq_beta=0.1 --vq_lambda_reg=0.1 --vq_temp=1.0 --bs=32  --layers=16    --log_level=DEBUG --dropout_rate=0.1
# python main.py --model="mop" --exp_name="vqvae" --env="ant" --ac_model=vae --vq  --vq_hard --vq_num_emb=256 --vq_cd=1.0 --vq_beta=0.1 --vq_lambda_reg=0.1 --vq_temp=1.0 --bs=32  --layers=64,64 --log_level=DEBUG --dropout_rate=0.1
# python main.py --model="mop" --exp_name="mm" --env="ant" --ac_model=mm  --bs=32  --layers=64,64,64 --log_level=DEBUG --dropout_rate=0.1 --n_components=3 --repulsion_beta=2.0