import glob
import os
import subprocess

from absl import app, flags, logging

# Command-line flags for the copy script.
#   --path             source run directory: a local directory, a "gs..." GCS
#                      location, or an scp-style "user@host:dir" remote.
#   --checkpoint_step  name of the step sub-directory inside --path.
#   --save_path        local directory under which the run is copied.
# --path and --checkpoint_step default to None.
flags.DEFINE_string(name="path", default=None, help="Path to checkpoint directory")
flags.DEFINE_string(name="checkpoint_step", default=None, help="Step number")
flags.DEFINE_string(name="save_path", default=".", help="Path to save checkpoint")

FLAGS = flags.FLAGS


def _copy_command(path):
    """Return the argv prefix of the copy tool that can read from *path*.

    ``gs://`` URLs are copied with ``gsutil cp``, paths containing ``@``
    (scp-style ``user@host:dir``) with ``scp``, and anything else with ``cp``.
    """
    # Match the full URL scheme: a bare "gs" prefix would also catch local
    # directories such as "gsdata/run1".
    if path.startswith("gs://"):
        return ["gsutil", "cp"]
    if "@" in path:
        return ["scp"]
    return ["cp"]


def _expand_sources(pattern, is_local):
    """Return the source arguments to hand to the copy tool for *pattern*.

    Commands run without a shell, so nothing expands ``*`` implicitly. Local
    patterns are expanded here with ``glob.glob`` (sorted for a stable order).
    Remote patterns are passed through unchanged, leaving wildcard expansion
    to ``gsutil`` / ``scp``. An empty list means no local file matched.
    """
    if not is_local:
        return [pattern]
    return sorted(glob.glob(pattern))


def _run_copy(argv):
    """Run one copy command, logging a warning if it does not succeed.

    Failures are reported but not raised, so the remaining files are still
    attempted.
    """
    try:
        returncode = subprocess.run(argv, check=False).returncode
    except OSError as err:  # e.g. the copy tool is not installed
        logging.warning("Could not run %r: %s", argv, err)
        return
    if returncode != 0:
        logging.warning("Copy failed with exit code %d: %r", returncode, argv)


def main(_):
    """Copy one checkpoint step and its metadata into ``--save_path/<run_name>``.

    From ``--path`` this copies the ``<checkpoint_step>`` directory
    (recursively) plus every file matching ``dataset_statistics*``,
    ``config*`` and ``example_batch.msgpack*``. ``<run_name>`` is the last
    component of ``--path``. Individual copy failures are logged as warnings.

    Raises:
        app.UsageError: if ``--path`` or ``--checkpoint_step`` is not set.
    """
    if not FLAGS.path:
        raise app.UsageError("--path is required.")
    if not FLAGS.checkpoint_step:
        raise app.UsageError("--checkpoint_step is required.")

    cmd = _copy_command(FLAGS.path)
    is_local = cmd == ["cp"]

    checkpoint_path = os.path.join(FLAGS.path, FLAGS.checkpoint_step)
    run_name = os.path.basename(os.path.normpath(FLAGS.path))
    save_path = os.path.join(FLAGS.save_path, run_name)

    # Construct the save path.
    os.makedirs(save_path, exist_ok=True)
    destination = f"{save_path}/"

    # The checkpoint step is a directory, so it needs a recursive copy.
    # Arguments are passed as a list (no shell), so paths with spaces or
    # shell metacharacters are handled safely.
    _run_copy([*cmd, "-r", checkpoint_path, destination])

    # Copy the relevant metadata files over; each is matched by wildcard.
    for name in ("dataset_statistics*", "config*", "example_batch.msgpack*"):
        pattern = os.path.join(FLAGS.path, name)
        sources = _expand_sources(pattern, is_local)
        if not sources:
            logging.warning("No files match %s", pattern)
            continue
        _run_copy([*cmd, *sources, destination])


if __name__ == "__main__":
    # Script entry point: app.run parses the absl flags defined above from
    # sys.argv and then invokes main.
    app.run(main)
