from huggingface_hub import HfApi, HfFolder, Repository
import os
import json
import requests
import shutil
# from huggingface_hub.utils import HfHubHTTPError, RepositoryNotFoundError

from verl.model_merger.base_model_merger import ModelMergerConfig
from verl.model_merger.fsdp_model_merger import FSDPModelMerger

def load_file(file_path: str):
    """Parse a UTF-8 JSON file and return its contents.

    Args:
        file_path: Path to the JSON file.

    Returns:
        The decoded JSON value, or an empty dict when ``file_path`` does not
        exist (as reported by ``os.path.exists``).
    """
    if not os.path.exists(file_path):
        return {}
    with open(file_path, "r", encoding="utf-8") as handle:
        return json.load(handle)

def check_model_size(model_path: str):
    """Infer the model-size label (used as a reference-config subdirectory) from a path.

    Matching is case-insensitive, and a size token must not be glued to an
    adjacent digit or letter. So "Qwen3-4B" gives "4B", while "qwen3-14b" or
    "llama-4bit" do not match "4B". Candidates are tried in the fixed order
    1.7B, 4B, 8B, so when several appear in the path the earlier candidate wins.

    Args:
        model_path: Filesystem path (or name) of the model.

    Returns:
        "1.7B", "4B", "8B", or the string "None" when no size is recognised.
    """
    import re

    lowered = model_path.lower()
    candidates = (
        (r"1\.7b", "1.7B"),
        (r"4b", "4B"),
        (r"8b", "8B"),
    )
    for pattern, label in candidates:
        # Reject a preceding digit/dot ("14b") and a following digit/letter
        # ("4bit", "8b3f") so that only a standalone size token matches.
        if re.search(r"(?<![0-9.])" + pattern + r"(?![0-9a-z])", lowered):
            return label
    return "None"

def upload_to_huggingface(model_path: str, repo_name: str, hf_token: str, do_merge: bool, do_upload: bool):
    """Optionally restore reference configs into ``model_path``, then optionally push it to the Hub.

    Args:
        model_path: Local directory holding the merged model files.
        repo_name: Target Hub repo id, e.g. ``"username/my-model"``.
        hf_token: Access token, passed straight through to ``huggingface_hub``.
        do_merge: When true, overwrite ``config.json``, ``generation_config.json``
            and ``tokenizer_config.json`` in ``model_path`` with the copies found in
            ``rllm/trainer/model_config_ori/<size>``. ``<size>`` comes from
            ``check_model_size``. Reference files that do not exist are skipped,
            so the model's own file is kept rather than replaced by ``{}``.
        do_upload: When true, create the (public) repo if needed and upload the
            whole ``model_path`` folder.
    """
    config_keys = ("config", "generation_config", "tokenizer_config")

    api = None
    if do_upload:
        api = HfApi(token=hf_token)
        api.create_repo(repo_id=repo_name, repo_type="model",
                        private=False, exist_ok=True, token=hf_token)

    if do_merge:
        model_size = check_model_size(model_path)
        # NOTE(review): this directory is resolved relative to the current working
        # directory, so the script presumably must be launched from the repo root.
        ref_model_config_dir = f"rllm/trainer/model_config_ori/{model_size}"

        # Load every available reference file before writing anything, so a
        # parse error cannot leave the model directory half-updated.
        ref_config = {}
        for key in config_keys:
            ref_path = os.path.join(ref_model_config_dir, f"{key}.json")
            if os.path.exists(ref_path):
                ref_config[key] = load_file(ref_path)
            else:
                # Previously a missing reference (e.g. model_size == "None") was
                # written out as {}, wiping the merged model's real config.
                print(f"[upload_to_huggingface] Reference {ref_path} not found; "
                      f"keeping existing {key}.json in {model_path}.")

        for key, value in ref_config.items():
            with open(os.path.join(model_path, f"{key}.json"), "w", encoding="utf-8") as f:
                json.dump(value, f, ensure_ascii=False, indent=4)

    if do_upload:
        api.upload_folder(
            folder_path=model_path,
            repo_id=repo_name,
            repo_type="model",
            token=hf_token,
            commit_message="Upload model and configs",
        )


def main():
    """Command-line entry point: merge an FSDP checkpoint if needed, then sync configs and optionally upload.

    The merged model is written to ``<model_path>/merged``. If that directory
    already exists, both the merge and the config rewrite are skipped.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Merge an FSDP checkpoint into a Hugging Face model and optionally upload it."
    )
    parser.add_argument("--model_path", type=str, default="./trained_model",
                        help="FSDP checkpoint directory; the HF model config is read from "
                             "<model_path>/huggingface and the merged model goes to <model_path>/merged.")
    parser.add_argument("--repo_name", type=str, default="username/my-awesome-model",
                        help="Target Hugging Face Hub repo id.")
    parser.add_argument("--hf_token", type=str,
                        help="Hugging Face access token.")
    parser.add_argument("--no_merge", action="store_true",
                        help="Do not overwrite config/generation_config/tokenizer_config with the reference copies.")
    parser.add_argument("--do_upload", action="store_true",
                        help="Create the Hub repo and upload the merged model folder.")
    args = parser.parse_args()

    model_path = args.model_path
    merged_model_path = os.path.join(model_path, "merged")
    has_merged = os.path.exists(merged_model_path)

    # Configs are only rewritten right after a fresh merge. An existing "merged"
    # dir is presumably already processed.
    # NOTE(review): a partially written "merged" dir from an interrupted run is
    # also treated as done.
    do_merge_config = not args.no_merge and not has_merged

    if not has_merged:
        merge_config = ModelMergerConfig(
            operation="merge",
            backend="fsdp",
            local_dir=model_path,
            hf_model_config_path=os.path.join(model_path, "huggingface"),
            target_dir=merged_model_path
        )
        merger = FSDPModelMerger(merge_config)
        try:
            merger.merge_and_save()
        finally:
            # Run cleanup even if the merge fails part-way.
            merger.cleanup()

    upload_to_huggingface(merged_model_path, args.repo_name, args.hf_token, do_merge_config, args.do_upload)


if __name__ == "__main__":
    main()