import os
import jax

import wandb


# Process-wide JAX/XLA settings. They run before the project imports below
# (data, configs, trainer), presumably so those modules already see 64-bit
# mode enabled if they create arrays at import time -- TODO confirm.
#
# XLA presumably reads this variable when the backend is first initialized,
# so it must be set before any JAX computation runs.
# NOTE(review): JAX preallocates GPU memory by default, so "True" appears to
# restate the default; if the intent was to disable preallocation the value
# should be "false" -- confirm.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "True"
# Enable 64-bit dtypes globally (JAX uses 32-bit types unless this is set).
jax.config.update("jax_enable_x64", True)

from data import load_ds
from configs import get_config
from trainer import Trainer


def main():
    """Load the config and data, then train inside a Weights & Biases run.

    Reads the project config via ``get_config()`` and opens a W&B run whose
    project and name come from ``config.wandb``. Inside that run it loads the
    train/test splits and the two scale values (``force_scale``,
    ``tensor_scale``) via ``load_ds(config.dataset)``. It then calls
    ``Trainer(config).train(...)`` with them. The value returned by
    ``train`` is discarded.
    """
    config = get_config()
    wandb_config = config.wandb
    # Using the run as a context manager guarantees it is finished: on normal
    # completion, and on an exception (recorded with a non-zero exit code).
    # Without it, the run stays open until interpreter shutdown and a second
    # call to main() in the same process would run against a still-active run.
    with wandb.init(project=wandb_config.project, name=wandb_config.name):
        (train_ds, test_ds), (force_scale, tensor_scale) = load_ds(
            config.dataset
        )
        trainer = Trainer(config)
        trainer.train(train_ds, test_ds, force_scale, tensor_scale)

if __name__ == "__main__":
    # Start training only when run as a script, not when this module is imported.
    main()