import argparse
import os

from nesim.utils.moving_checkpoints import (
    download_checkpoint_from_remote,
    upload_checkpoint_to_remote,
)

def _parse_topo_scale(value: str) -> int:
    """Parse the ``--topo-scale`` command-line value.

    Accepts either an integer or the literal string ``'baseline'``, which is
    mapped to 0.

    Raises:
        argparse.ArgumentTypeError: if *value* is neither an integer nor
            ``'baseline'``; argparse reports this as a usage error for
            ``--topo-scale``.
    """
    if value == "baseline":
        return 0
    try:
        return int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"expected an integer or 'baseline', got {value!r}"
        ) from None


# Command-line interface: which machine to pull the checkpoint from, which to
# push it to, and which checkpoint (topo scale + global step) to move.
parser = argparse.ArgumentParser(
    description="Move checkpoints from one machine to another"
)
parser.add_argument(
    "--source",
    type=str,
    required=True,
    help="name of the machine to download the checkpoint from",
)
parser.add_argument(
    "--destination",
    type=str,
    required=True,
    help="name of the machine to upload the checkpoint to",
)
parser.add_argument(
    "--topo-scale",
    # A custom type is required so 'baseline' is accepted: with type=int,
    # argparse rejects it before any post-processing could map it to 0.
    type=_parse_topo_scale,
    required=True,
    help="either some number or 'baseline' (treated as 0)",
)
parser.add_argument("--global-step", type=int, required=True, help="global step number")

args = parser.parse_args()

# Connection details for every machine checkpoints can be moved between.
# The top-level names are what --source / --destination refer to; each entry
# holds the SSH hostname, port, login user and the checkpoints directory on
# that machine.
machines = dict(
    penfield=dict(
        hostname="penfield.psych.XXXX-7.edu",
        port=22,
        username="XXXX-4",
        checkpoints_dir="/home/XXXX-4/repos/nesim/training/gpt_neo_125m/checkpoints",
    ),
    barlow=dict(
        hostname="barlow.psych.XXXX-7.edu",
        port=2222,  # non-default SSH port on this host
        username="XXXX-4",
        checkpoints_dir="/research/XXXX-4/nesim/training/gpt_neo_125m/checkpoints",
    ),
    pace=dict(
        hostname="login-phoenix-rh9.pace.XXXX-7.edu",
        port=22,
        username="XXXX-4",
        checkpoints_dir="/storage/home/hcoda1/4/XXXX-4/p-XXXX-6-0/repos/nesim/training/gpt_neo_125m/checkpoints",
    ),
)

# Step 1: fetch the requested checkpoint from the source machine. The returned
# path is handed to the upload step below; presumably it points inside
# cache_dir on this machine — confirm against download_checkpoint_from_remote.
# NOTE(review): run=True presumably performs the transfer rather than a dry
# run — confirm in nesim.utils.moving_checkpoints.
checkpoint_path = download_checkpoint_from_remote(
    machines=machines,
    machine_name=args.source,
    cache_dir="./cache",  # relative to the current working directory
    topo_scale=args.topo_scale,
    global_step=args.global_step,
    run=True,
)

# Step 2: push the downloaded checkpoint to the destination machine.
# The password is read from the environment so a real credential never has
# to be written into this file; the old placeholder remains the fallback so
# behaviour is unchanged when the variable is not set.
checkpoint_path_remote = upload_checkpoint_to_remote(
    machines=machines,
    machine_name=args.destination,
    topo_scale=args.topo_scale,
    global_step=args.global_step,
    checkpoint_path=checkpoint_path,
    run=True,
    password=os.environ.get("MOVE_CHECKPOINTS_PASSWORD", "PASSWORD"),
)
print(f"checkpoint_path_remote: {checkpoint_path_remote}")

"""Examples
python move_checkpoints.py --source barlow --destination penfield --topo-scale 0 --global-step 800

python move_checkpoints.py --source pace --destination penfield --topo-scale 50 --global-step 800
"""
