

import argparse
import pathlib
import shlex
import shutil
import subprocess


def make_parser():
    """Build the command-line parser for this script.

    Returns:
        An ``argparse.ArgumentParser``. ``--local_dir`` is required and parsed into a
        ``pathlib.Path``; every other option is optional.
    """
    parser = argparse.ArgumentParser(
        description="Unshard S3 checkpoint and convert to HF format. Invoke this script from the root of the OLMo repo."
    )
    parser.add_argument("--sharded_bucket", help="S3 bucket with sharded checkpoint.", type=str)
    parser.add_argument(
        "--unsharded_bucket",
        help="S3 bucket to save the unsharded checkpoint.",
        type=str,
    )
    parser.add_argument(
        "--already_downloaded",
        action="store_true",
        help="Use this flag if the sharded S3 checkpoint is already downloaded, but still needs to be unsharded.",
    )
    parser.add_argument(
        "--already_unsharded",
        action="store_true",
        help="If given, the checkpoint has already been unsharded; just convert to HF.",
    )
    parser.add_argument("--hf_bucket", help="S3 bucket to save the HF-converted checkpoint.", type=str)
    parser.add_argument(
        "--local_dir",
        # Required: the pipeline stores all intermediate files here and cannot run without it.
        required=True,
        help="Local working directory. The sharded, unsharded and HF-converted checkpoints "
        "are stored in subdirectories of it.",
        type=pathlib.Path,
    )
    parser.add_argument(
        "--cleanup_local_dir",
        action="store_true",
        help="If given, remove the local directory if everything runs successfully to free up space on NFS.",
    )
    parser.add_argument(
        "--checkpoint_style",
        default="hf_olmo",
        choices=["hf_olmo", "transformers"],
        help="Style of HF checkpoint to produce. Selects the conversion script and the "
        "output subdirectory of --local_dir ('hf-olmo' or 'transformers').",
    )
    # This flag is not read anywhere in the script; --checkpoint_style decides the
    # output format. It is kept so existing command lines still parse.
    parser.add_argument(
        "--hf_olmo",
        action="store_true",
        help="Ignored; use '--checkpoint_style hf_olmo' instead. Kept for backward compatibility.",
    )
    parser.add_argument(
        "--quiet",
        action="store_true",
        help="If given, don't show progress for AWS commands.",
    )
    parser.add_argument("--type", default=None, help="If given, pass this argument on to `unshard.py`.")
    parser.add_argument("--model_only", action="store_true", help="If given, only unshard the model.")
    return parser


def aws_copy(src, dest, args):
    """Build, but do not run, an ``aws s3 sync`` shell command that copies ``src`` to ``dest``.

    Files under ``tmp/`` are always excluded. When ``args.type == "olmo_core"`` and
    ``args.model_only`` is set, ``optim/`` and ``train/`` are excluded as well.

    Args:
        src: Source location, either an S3 URI string or a local path.
        dest: Destination location, either an S3 URI string or a local path.
        args: Parsed CLI namespace. Reads ``quiet``, ``type`` and ``model_only``.

    Returns:
        The command as one string, meant for ``subprocess.run(cmd, shell=True)``.
        Each argument is shell-quoted, so the ``--exclude`` glob patterns reach the
        AWS CLI verbatim instead of being expanded by the shell against the current
        working directory. Paths containing spaces are also handled.
    """
    parts = ["aws", "s3", "sync", "--exclude", "tmp/*"]
    if args.quiet:
        parts.append("--quiet")
    if args.type == "olmo_core" and args.model_only:
        # Presumably optimizer / trainer state, which a model-only unshard does not need.
        parts += ["--exclude", "optim/*", "--exclude", "train/*"]
    parts += [str(src), str(dest)]
    return " ".join(shlex.quote(part) for part in parts)


def s3_unshard_to_hf(args):
    """Download a checkpoint from S3, unshard it, convert it to HF format, and upload the results.

    All work happens under ``args.local_dir``, in the subdirectories ``sharded/``,
    ``unsharded/`` and either ``hf-olmo/`` or ``transformers/`` (chosen by
    ``args.checkpoint_style``). Every external command runs with ``check=True``, so
    a failing step raises and the remaining steps are skipped.

    Args:
        args: Parsed namespace from ``make_parser()``.

    Raises:
        ValueError: If ``args.checkpoint_style`` is not a known style.
        subprocess.CalledProcessError: If any external command exits non-zero.
    """

    def quote(value):
        # Commands below run through the shell; quote interpolated values.
        return shlex.quote(str(value))

    # Set up local directories.
    sharded_dir = args.local_dir / "sharded"
    unsharded_dir = args.local_dir / "unsharded"
    if args.checkpoint_style == "hf_olmo":
        hf_dir = args.local_dir / "hf-olmo"
    elif args.checkpoint_style == "transformers":
        hf_dir = args.local_dir / "transformers"
    else:
        raise ValueError(f"Unknown checkpoint style: {args.checkpoint_style}.")
    hf_dir.mkdir(parents=True, exist_ok=True)

    # Obtain an unsharded checkpoint: either download one directly, or download
    # (unless already present) the sharded checkpoint and unshard it locally.
    if args.already_unsharded:
        print("Downloading unsharded checkpoint from S3.")
        download_cmd = aws_copy(args.unsharded_bucket, unsharded_dir, args)
        subprocess.run(download_cmd, shell=True, check=True)
    else:
        if not args.already_downloaded:
            print("Downloading sharded checkpoint from S3.")
            download_cmd = aws_copy(args.sharded_bucket, sharded_dir, args)
            subprocess.run(download_cmd, shell=True, check=True)

        print("Unsharding.")
        unshard_cmd = f"python scripts/unshard.py {quote(sharded_dir)} {quote(unsharded_dir)}"
        # Forward the optional flags to unshard.py.
        if args.type is not None:
            unshard_cmd += f" --type {quote(args.type)}"
        if args.model_only:
            unshard_cmd += " --model-only"
        subprocess.run(unshard_cmd, shell=True, check=True)

    # Convert the unsharded checkpoint into hf_dir.
    print("Converting to HF.")
    if args.checkpoint_style == "hf_olmo":
        hf_cmd = f"python hf_olmo/convert_olmo_to_hf.py --checkpoint-dir {quote(unsharded_dir)}"
        subprocess.run(hf_cmd, shell=True, check=True)
        # The hf_olmo converter is expected to leave these files in the unsharded
        # directory; move them so hf_dir holds only the HF checkpoint.
        for fname in [
            "config.json",
            "pytorch_model.bin",
            "special_tokens_map.json",
            "tokenizer.json",
            "tokenizer_config.json",
        ]:
            (unsharded_dir / fname).rename(hf_dir / fname)
    else:
        # NOTE(review): assumes the OLMo repo's scripts/convert_olmo_to_hf_new.py and
        # these flags; confirm the script path, flags and tokenizer path.
        hf_cmd = (
            "python scripts/convert_olmo_to_hf_new.py"
            f" --input_dir {quote(unsharded_dir)}"
            f" --output_dir {quote(hf_dir)}"
            " --tokenizer_json_path tokenizers/allenai_gpt-neox-olmo-dolma-v1_5.json"
        )
        subprocess.run(hf_cmd, shell=True, check=True)

    # Upload results. The unsharded checkpoint is uploaded only if it was created
    # here rather than downloaded from args.unsharded_bucket.
    print("Uploading files back to S3.")
    if not args.already_unsharded:
        upload_unsharded_cmd = aws_copy(unsharded_dir, args.unsharded_bucket, args)
        subprocess.run(upload_unsharded_cmd, shell=True, check=True)

    upload_hf_cmd = aws_copy(hf_dir, args.hf_bucket, args)
    subprocess.run(upload_hf_cmd, shell=True, check=True)


def main():
    """Script entry point: parse arguments, run the pipeline, then optionally delete local files."""
    parser = make_parser()
    args = parser.parse_args()
    if args.local_dir is None:
        # Everything below stores files under local_dir; fail with a clear CLI error.
        parser.error("--local_dir is required.")
    args.local_dir.mkdir(exist_ok=True, parents=True)

    s3_unshard_to_hf(args)

    # Reached only if every step succeeded, because failing commands raise.
    if args.cleanup_local_dir:
        # Free up space by deleting the working directory (previously this
        # referenced a non-existent `args.tmp_dir` and always crashed).
        shutil.rmtree(args.local_dir)


if __name__ == "__main__":
    main()
