from typing import Optional
from vmoe import multihost_utils
import os, orbax, jax, tensorflow as tf  # noqa: E401


def create_checkpoint_manager(
    *,
    workdir: str,
    every_steps: int,
    keep_last: Optional[int] = None,
    keep_steps_multiple_of: Optional[int] = None,
    wait_seconds: int = 300,
) -> orbax.checkpoint.CheckpointManager:
  """Builds an Orbax `CheckpointManager` rooted at `<workdir>/ckpt`.

  The checkpoint directory is created by process 0 if it does not exist yet,
  after which `multihost_utils.sync_devices` is called before the manager is
  constructed.

  Args:
    workdir: Parent directory; checkpoints are placed in its 'ckpt' subdir.
    every_steps: Forwarded to the manager options as `save_interval_steps`.
    keep_last: Forwarded to the manager options as `max_to_keep`.
    keep_steps_multiple_of: Forwarded to the manager options as `keep_period`.
    wait_seconds: Forwarded as `timeout_secs` to the asynchronous checkpointer
      that handles the 'state' item.

  Returns:
    A `CheckpointManager` with two items: 'state' (asynchronous PyTree
    checkpointer) and 'dataset_iterator' (JSON checkpointer).
  """
  ckpt_dir = os.path.join(workdir, 'ckpt')

  # Only the first process touches the filesystem to create the directory.
  if jax.process_index() == 0:
    if not tf.io.gfile.exists(ckpt_dir):
      tf.io.gfile.makedirs(ckpt_dir)
  # Presumably a cross-host barrier, so no process continues before the
  # directory has been created -- verify against multihost_utils.
  multihost_utils.sync_devices('create-ckpt-dir')

  manager_options = orbax.checkpoint.CheckpointManagerOptions(
      save_interval_steps=every_steps,
      max_to_keep=keep_last,
      keep_period=keep_steps_multiple_of,
  )

  state_checkpointer = orbax.checkpoint.AsyncCheckpointer(
      orbax.checkpoint.PyTreeCheckpointHandler(),
      timeout_secs=wait_seconds,
  )
  iterator_checkpointer = orbax.checkpoint.Checkpointer(
      orbax.checkpoint.JsonCheckpointHandler()
  )
  checkpointers = {
      'state': state_checkpointer,
      'dataset_iterator': iterator_checkpointer,
  }

  return orbax.checkpoint.CheckpointManager(
      ckpt_dir,
      checkpointers,
      options=manager_options,
  )


def main():
  """Builds a checkpoint manager under a hard-coded working directory.

  Calls `create_checkpoint_manager` with workdir
  "prior_checkpoints/test_checkpoint", `every_steps=1000` and `keep_last=3`.
  The returned manager is not used any further, so it is not bound to a name.
  """
  create_checkpoint_manager(
      workdir="prior_checkpoints/test_checkpoint",
      every_steps=1000,
      keep_last=3,
  )

# Call main() only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()