import jax 
import os

import wandb

from configs import get_config
from trainer import Trainer
from data import load_ds


# Ask XLA to preallocate accelerator memory for this process. JAX documents
# preallocation as the default behaviour, so this only makes the choice explicit.
# NOTE(review): this runs after `import jax` and the project imports above;
# presumably it must be set before JAX initializes its backend to have any
# effect -- confirm none of those imports triggers backend initialization.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "True"
# Turn on JAX's 64-bit mode via the `jax_enable_x64` config flag.
jax.config.update("jax_enable_x64", True)


def main():
    """Run one training job from configuration to the end of training.

    Steps, in order: fetch the config via ``get_config()``, start a
    Weights & Biases run using the ``project`` and ``name`` fields of
    ``config.wandb``, build a ``Trainer`` from the config, load the train
    and test datasets described by ``config.dataset``, and call
    ``Trainer.train`` on them.
    """
    cfg = get_config()

    # Experiment tracking is started before the trainer is constructed.
    tracking = cfg.wandb
    wandb.init(project=tracking.project, name=tracking.name)

    model_trainer = Trainer(cfg)
    train_split, test_split = load_ds(cfg.dataset)

    # Whatever train() returns is not needed by this script.
    model_trainer.train(train_split, test_split)


# Run the training job only when this file is executed as a script,
# not when it is imported as a module.
if __name__=="__main__":
    main()
