"""Training entry point: configure JAX/XLA, then run AdvTrainer."""
import os

# Set XLA client options before importing jax so they are guaranteed to be
# in the environment when the backend initializes (see JAX GPU memory
# allocation docs).
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "True"

import jax  # noqa: E402  (must follow the environment setup above)

# Enable 64-bit types before the project modules below are imported, in case
# any of them create arrays at import time.
jax.config.update("jax_enable_x64", True)

from data import load_ds  # noqa: E402
from configs import get_config  # noqa: E402
from trainer import AdvTrainer  # noqa: E402
import wandb  # noqa: E402




def main():
    """Load the config, start a W&B run, and train an AdvTrainer.

    The W&B run is used as a context manager so that it is finished when
    this function exits, and marked as failed if training raises. Without
    this, the run would be left open until interpreter shutdown.

    Returns:
        Whatever ``AdvTrainer.train`` returns.
    """
    config = get_config()
    wandb_config = config.wandb
    with wandb.init(project=wandb_config.project, name=wandb_config.name):
        train_ds, test_ds = load_ds(config.dataset)
        trainer = AdvTrainer(config)
        return trainer.train(train_ds, test_ds)


if __name__ == "__main__":
    # Only start training when run as a script, not when imported.
    main()
