import os
import shutil

from accelerate.utils import extract_model_from_parallel as unwrap_model
from huggingface_hub import create_repo, upload_folder


def upload_hf_train_end(
    trainer, model, train_method, repo_id, *, push_adapters_only=True
):
    """Export the trained model to a private Hugging Face Hub repo after training.

    Only the world-process-zero rank does any work; every other rank returns
    immediately without touching the filesystem or the Hub.

    The model held by ``trainer`` is unwrapped from its parallel wrapper and
    saved (with ``safe_serialization=True``) together with ``model.tokenizer``
    into a freshly recreated ``<trainer.args.output_dir>/hub_export`` directory.
    The target repo is created as private if it does not exist yet, and then
    the result is pushed.

    Args:
        trainer: Trainer-like object; ``is_world_process_zero()``, ``model`` and
            ``args.output_dir`` are read from it.
        model: Object whose ``tokenizer`` attribute is saved next to the weights.
        train_method: Free-form label interpolated into the Hub commit message.
        repo_id: Hub repository id to create and push to.
        push_adapters_only: If True (default; the previously hard-coded
            behavior), push via ``push_to_hub`` on the unwrapped model —
            presumably a PEFT/LoRA model, so only adapter weights go up (TODO
            confirm). If False, upload the whole ``hub_export`` folder with
            ``upload_folder``.
    """
    if not trainer.is_world_process_zero():
        return

    model_unwrapped = unwrap_model(trainer.model)
    # Strip one remaining ``.module`` layer if a wrapper survived unwrapping.
    if hasattr(model_unwrapped, "module"):
        model_unwrapped = model_unwrapped.module

    # Recreate the export directory so files from a previous run never leak in.
    out_dir = os.path.join(trainer.args.output_dir, "hub_export")
    if os.path.exists(out_dir):
        shutil.rmtree(out_dir)
    os.makedirs(out_dir, exist_ok=True)

    create_repo(repo_id, private=True, exist_ok=True)

    # Both modes save locally first. Per the original author, this save is
    # what performs weight consolidation (not needed under DeepSpeed).
    model_unwrapped.save_pretrained(out_dir, safe_serialization=True)
    model.tokenizer.save_pretrained(out_dir)

    if push_adapters_only:
        # NOTE(review): push_to_hub serializes the model on its own and does
        # not upload the tokenizer files written to out_dir above — confirm
        # whether the tokenizer should also be pushed in this mode.
        model_unwrapped.push_to_hub(
            repo_id,
            commit_message=f"LoRA adapters after training ({train_method})",
            private=True,
        )
    else:
        upload_folder(
            folder_path=out_dir,
            repo_id=repo_id,
            repo_type="model",
            commit_message=f"Full model after training ({train_method})",
        )
