import fiddle as fdl
from absl import app
from fiddle import absl_flags as fdl_flags

import tabular_mvdrl.trainer  # noqa: F401
from tabular_mvdrl.utils import printing as custom_printing

# Command-line flag ``--trainer`` whose value is a Fiddle config; ``main``
# reads it through ``_CONFIG.value`` and builds it into a trainer. Declared
# with ``default=None``, so its value is None when the flag is not passed.
_CONFIG = fdl_flags.DEFINE_fiddle_config(
    "trainer",
    default=None,
    help_string="Fiddle config module for trainer",
)


def main(_):
    """Build the trainer from the ``--trainer`` Fiddle config and train it.

    Before training starts, the config is flattened into a hyperparameter
    dict, the concrete trainer class name is stored under the ``"trainer"``
    key, and the result is written via the trainer's ``metric_writer``.

    Args:
      _: Positional command-line arguments remaining after absl flag
        parsing; unused.

    Raises:
      app.UsageError: If no ``--trainer`` config was supplied.
    """
    config = _CONFIG.value
    # The flag is declared with ``default=None``; without this check a missing
    # ``--trainer`` surfaces later as an opaque error (e.g. an attribute
    # lookup on ``None``). Fail fast with a usage message instead.
    if config is None:
        raise app.UsageError("A trainer config must be supplied via --trainer.")
    trainer = fdl.build(config)
    # Hyperparameters are derived from the (unbuilt) config; the built
    # trainer's class name is recorded alongside them.
    hparams = custom_printing.as_dict(config, flatten_tree=True)
    hparams["trainer"] = type(trainer).__name__
    trainer.metric_writer.write_hparams(hparams)
    trainer.train()


if __name__ == "__main__":
    # Hand control to absl, which parses command-line flags (including
    # --trainer) and then invokes ``main``.
    app.run(main)
