"""
A script to compress a directory of model files and upload the archive to the Hugging Face Hub.
"""
import os
import argparse
import tarfile

from huggingface_hub import HfApi


def push_model(
    upload_path: str,
    entity: str = "nan",
    repo_name: str = "implicit_nonlinear_dynamics",
    branch: str = "main",
) -> None:
    """
    Uploads model to hugging face repository.

    Ensures the model repository ``{entity}/{repo_name}`` and the requested
    branch exist (the repository is requested as private when it has to be
    created), then uploads ``upload_path`` to that branch. A directory is
    uploaded as a folder; a single file is uploaded to the repository root
    under its own base name.

    Args:
        upload_path: Local path of the directory or file to upload.
        entity: User or organisation that owns the repository.
        repo_name: Name of the repository under ``entity``.
        branch: Branch (revision) that receives the upload.

    Raises:
        NotImplementedError: If ``upload_path`` is neither an existing
            directory nor an existing regular file. This is checked before
            any request is made, so nothing is created remotely in that case.
    """
    # Validate the local path first: failing here avoids leaving behind an
    # empty repository/branch when there is nothing that can be uploaded.
    is_folder = os.path.isdir(upload_path)
    if not is_folder and not os.path.isfile(upload_path):
        raise NotImplementedError(
            f"upload_path must be an existing file or directory: {upload_path!r}"
        )

    api = HfApi()
    repo_id = f"{entity}/{repo_name}"

    # ensure repo exists
    api.create_repo(
        repo_id=repo_id,
        repo_type="model",
        exist_ok=True,
        private=True,
    )

    # create model branch
    api.create_branch(
        repo_id=repo_id,
        branch=branch,
        repo_type="model",
        exist_ok=True,
    )

    # upload requested files
    if is_folder:
        # NOTE(review): `multi_commits` is an experimental huggingface_hub
        # option -- confirm the installed version still accepts it.
        api.upload_folder(
            folder_path=upload_path,
            repo_id=repo_id,
            repo_type="model",
            multi_commits=True,
            revision=branch,
        )
    else:
        api.upload_file(
            path_or_fileobj=upload_path,
            # basename handles the platform's path separator, unlike split("/")
            path_in_repo=os.path.basename(upload_path),
            repo_id=repo_id,
            repo_type="model",
            revision=branch,
        )


def main():
    """
    Command-line entry point: archive a model directory and upload it.

    Parses four positional arguments (``entity``, ``repo``, ``branch`` and
    ``model_save_dir``), writes ``checkpoint.tar.xz`` inside
    ``model_save_dir`` from that directory, and hands the archive to
    ``push_model`` for upload.
    """
    # (name, help) pairs, registered in order as positional string arguments
    positional_args = (
        ("entity", "The entity to upload the model to"),
        ("repo", "The repository to upload the model to"),
        ("branch", "The branch to upload the model to"),
        ("model_save_dir", "The directory to save the model files"),
    )

    cli = argparse.ArgumentParser(description="")
    for arg_name, arg_help in positional_args:
        cli.add_argument(arg_name, type=str, help=arg_help)
    options = cli.parse_args()

    # bundle the model directory into a single xz-compressed tarball
    print("Compressing model files...")
    model_dir = options.model_save_dir
    archive_path = os.path.join(model_dir, "checkpoint.tar.xz")
    with tarfile.open(archive_path, "w:xz") as archive:
        # arcname="." stores the directory under "." instead of its given path
        archive.add(model_dir, arcname=".")

    # push the tarball to the requested repo/branch
    print("Uploading model to hugging face...")
    push_model(
        upload_path=archive_path,
        entity=options.entity,
        repo_name=options.repo,
        branch=options.branch,
    )


# Run the command-line entry point only when this file is executed as a
# script; importing it as a module does not trigger it.
if __name__ == "__main__":
    main()
