from lightning.pytorch.cli import LightningCLI
from datamodules import Adult, Crime, Compas, Health
from fairmodel import FairClassifier, FairClassifierSimple

if __name__ == "__main__":
    # Build the CLI without auto-dispatching a subcommand (run=False), so the
    # fit -> test sequence below is chained explicitly in a single process.
    # The concrete model/datamodule come from the command line or config file.
    # NOTE(review): the otherwise-unused imports at the top of the file are
    # presumably there so LightningCLI can resolve those classes by name.
    cli = LightningCLI(
        run=False,
        # No config-saving callback is installed for this run.
        save_config_callback=None,
        # Trainer(deterministic=True) unless overridden on the CLI/config.
        trainer_defaults={"deterministic": True},
    )

    # Pass the datamodule by keyword: the second positional parameter of
    # Trainer.fit is `train_dataloaders` (and of Trainer.test, `dataloaders`),
    # so the positional form only worked through Lightning's runtime fix-up.
    cli.trainer.fit(model=cli.model, datamodule=cli.datamodule)

    # No ckpt_path is given, so testing uses the model's in-memory weights as
    # they stand at the end of fit, not a reloaded "best" checkpoint.
    cli.trainer.test(model=cli.model, datamodule=cli.datamodule)
