"""Training and evaluation on a biased-exposure dataset.

Adapted from github.com/deepmind/dm-haiku/blob/master/examples/mnist.py.
"""

from absl import app
from absl import flags
from absl import logging

import os
import sys

import gin

import tensorflow as tf

from exposure_bias import data
from exposure_bias import models
from exposure_bias import train


# Root directory under which the default output locations are placed.
DEFAULT_DIR = "/tmp/exposure_bias/"


# Output locations.
flags.DEFINE_string(
  name="checkpoint_dir",
  default=os.path.join(DEFAULT_DIR, "checkpoints"),
  help="Directory in which to store model checkpoints.",
)

flags.DEFINE_string(
  name="summary_dir",
  default=os.path.join(DEFAULT_DIR, "summaries"),
  help="Directory in which to store summaries.",
)

# Gin configuration sources; each flag may be given multiple times.
flags.DEFINE_multi_string(
  name="gin_config",
  default=None,
  help="List of paths to the Gin config files.",
)
flags.DEFINE_multi_string(
  name="gin_binding",
  default=None,
  help="List of Gin parameter bindings.",
)

# Global flag registry; values are read from it inside `main`.
FLAGS = flags.FLAGS


def main(_):
  """Parses the Gin configuration, then trains and evaluates a model.

  Args:
    _: Unused; the argument that `app.run` passes to `main`.

  Raises:
    ValueError: Propagated unchanged from dataset loading, training, or
      evaluation, after the Gin configuration string has been logged.
  """
  gin.parse_config_files_and_bindings(
    FLAGS.gin_config, FLAGS.gin_binding, finalize_config=True
  )

  # Lazy %-style arguments, consistent with the other logging call below.
  logging.info(
    "Writing checkpoints to %s and summaries to %s.",
    FLAGS.checkpoint_dir,
    FLAGS.summary_dir,
  )

  try:
    train_data, valid_data, test_data = train.load_datasets()
    params, state = train.train(
      train_dataset=train_data,
      valid_dataset=valid_data,
      test_dataset=test_data,
    )
    train.evaluate(
      params=params,
      state=state,
      test_dataset=test_data,
    )

  except ValueError:
    # Log the Gin configuration before letting the error propagate.
    logging.info("Full Gin configurations:\n%s", gin.config_str())
    # Bare `raise` re-raises the active exception with its original traceback.
    raise


if __name__ == "__main__":
  # Point absl's Python log handler at stdout before starting the app.
  absl_handler = logging.get_absl_handler()
  absl_handler.python_handler.stream = sys.stdout
  app.run(main)
