import os

from accelerate.utils import extract_model_from_parallel as unwrap_model
from huggingface_hub import HfApi, HfFolder, create_repo, get_full_repo_name
from transformers import TrainerCallback


class PushLoRAEachEpochCallback(TrainerCallback):
    """Save the model after every epoch and upload the snapshot to the Hub.

    At the end of each epoch the (unwrapped) model, plus the tokenizer when one
    is available, is written to ``<args.output_dir>/epoch-<N>`` and that folder
    is uploaded to the ``epoch-<N>/`` subfolder of a private model repository.
    All Hub traffic happens on the world-zero process only.
    """

    def __init__(self, repo_id: str, tokenizer=None):
        """
        Args:
            repo_id: Either a bare name ("my-model") or a fully qualified id
                ("user-or-org/my-model"). It is resolved to a fully qualified
                id the first time the repository is needed.
            tokenizer: Optional tokenizer to save next to the model. When
                omitted, the one supplied by the Trainer (if any) is used.
        """
        self.repo_id_raw = repo_id
        self.repo_id_full = None
        self.tokenizer = tokenizer

    @staticmethod
    def _epoch_folder_name(epoch) -> str:
        """Return the ``epoch-<N>`` folder name for a (possibly float) epoch."""
        # The epoch counter is a float and may end slightly below the integer
        # boundary (e.g. 0.99); rounding avoids truncating that to the
        # previous epoch and mislabelling / overwriting a folder.
        return f"epoch-{round(epoch or 0)}"

    def _pick_tokenizer(self, kw):
        """Return the first available tokenizer, or ``None`` if there is none."""
        # "processing_class" is the kwarg name newer transformers releases are
        # expected to use in place of "tokenizer" -- TODO confirm against the
        # installed version. Explicit `is not None` checks ensure a None kwarg
        # never masks a tokenizer that was given to the constructor.
        for candidate in (
            kw.get("tokenizer"),
            self.tokenizer,
            kw.get("processing_class"),
        ):
            if candidate is not None:
                return candidate
        return None

    def _ensure_repo(self) -> str:
        """Resolve the fully qualified repo id and make sure the repo exists."""
        if self.repo_id_full is not None:
            return self.repo_id_full

        if "/" in self.repo_id_raw:
            # Already "<namespace>/<name>": use it verbatim so that no
            # namespace gets prepended a second time.
            full_id = self.repo_id_raw
        else:
            # Requires a prior huggingface_hub.login(...) (or a stored token).
            token = HfFolder.get_token()
            full_id = get_full_repo_name(self.repo_id_raw, token=token)

        create_repo(full_id, private=True, exist_ok=True)
        # Cache only after the repo is known to exist, so a failed attempt is
        # retried on the next call.
        self.repo_id_full = full_id
        return full_id

    def on_train_begin(self, args, state, control, **kw):
        """Cache the Trainer's tokenizer and create the Hub repo (rank zero)."""
        if self.tokenizer is None:
            self.tokenizer = self._pick_tokenizer(kw)

        # Only the process that uploads needs to talk to the Hub; this avoids
        # redundant (and racy) whoami/create_repo calls from every rank.
        if not state.is_world_process_zero:
            return
        self._ensure_repo()

    def on_epoch_end(self, args, state, control, **kw):
        """Save the current model to ``epoch-<N>`` and upload that folder.

        Raises:
            ValueError: If the Trainer did not pass a ``model`` kwarg.
        """
        if not state.is_world_process_zero:
            return

        model = kw.get("model", None)
        if model is None:
            raise ValueError(
                "PushLoRAEachEpochCallback.on_epoch_end requires a 'model' kwarg"
            )
        tok = self._pick_tokenizer(kw)

        # Strip distributed / parallel wrappers so save_pretrained is called
        # on the underlying model.
        m = unwrap_model(model)
        if hasattr(m, "module"):
            m = m.module

        folder_name = self._epoch_folder_name(state.epoch)

        # 1) save locally to .../epoch-X
        out = os.path.join(args.output_dir, folder_name)
        os.makedirs(out, exist_ok=True)
        m.save_pretrained(out, safe_serialization=True)
        if tok is not None:
            tok.save_pretrained(out)

        # 2) upload that local folder into the same subfolder on the Hub.
        # _ensure_repo() also covers the case where on_train_begin never ran
        # on this instance.
        api = HfApi()
        api.upload_folder(
            folder_path=out,  # local folder we just saved
            repo_id=self._ensure_repo(),  # fully-qualified "<namespace>/<repo>"
            repo_type="model",
            path_in_repo=folder_name,  # creates/updates epoch-X/ on the repo
            commit_message=f"Add {folder_name}",
        )
