import os
import jax
import wandb

# Process-wide JAX runtime configuration. Keep this before any JAX computation
# and before the project-local imports below, since those modules may create
# arrays at import time (assumption — confirm against data/configs/trainer).
# NOTE(review): the env var is set after `import jax`; this presumably still
# takes effect because JAX initialises its backend lazily — confirm.
# NOTE(review): "True" appears to match JAX's default GPU preallocation
# behaviour; presumably set explicitly for visibility — confirm intent.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "True"
# Enable 64-bit dtypes globally (JAX defaults to 32-bit precision otherwise).
jax.config.update("jax_enable_x64", True)


from data import load_ds
from configs import get_config
from trainer import AdvTrainer



def main():
    """Run one adversarial-training job and log it to Weights & Biases.

    Loads the configuration, opens a W&B run named from ``config.wandb``,
    loads the dataset together with its force/tensor scale factors, and
    trains an ``AdvTrainer`` on it.

    The W&B run is always finished before returning. If training raises, the
    run is marked as failed (exit code 1) and the exception is re-raised.
    This keeps repeated calls in the same process (e.g. from a sweep agent)
    from logging into a stale, still-open run.

    Returns:
        Whatever ``AdvTrainer.train`` returns. It is not inspected here.
    """
    config = get_config()
    wandb_config = config.wandb
    wandb.init(project=wandb_config.project, name=wandb_config.name)
    try:
        (train_ds, test_ds), (force_scale, tensor_scale) = load_ds(config.dataset)
        trainer = AdvTrainer(config)
        result = trainer.train(
            train_ds,
            test_ds,
            force_scale=force_scale,
            tensor_scale=tensor_scale,
        )
    except BaseException:
        # BaseException also covers KeyboardInterrupt, so an interrupted run
        # is recorded as failed rather than left open.
        wandb.finish(exit_code=1)
        raise
    wandb.finish()
    return result



if __name__ == "__main__":
    # Script entry point: only train when executed directly, not on import.
    main()
