#!/usr/bin/env python
"""
push_adapter_to_hub.py  (no-clone edition)

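Usage
-----
python push_adapter_to_hub.py \
    /path/to/adapter_folder \
    [--base-model-id bigscience/bloom-560m] \
    [--repo-id my-username/bloom-560m-adapter] \
    [--hf-token YOUR_TOKEN] \
    [--private] \
    [--branch main]

Defaults: if --base-model-id is omitted, it is read from adapter_config.json
in the adapter folder; if --repo-id is omitted, it is derived from the names
of the adapter folder's parent directories. A processor_config.json must
exist in the current working directory; it is uploaded with the adapter files.

Environment fallback: if --hf-token is omitted, $HF_TOKEN is used.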
"""

import os
import shutil
import json
from pathlib import Path

from huggingface_hub import (
    HfApi,
    login,
)

# --------------------------------------------------------------------------- #
# Parse command-line arguments 
import argparse

def parse_args(argv=None):
    """Parse the command-line options of the adapter upload script.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. Leave it as ``None`` for normal command-line
            use; an explicit list lets tests and other Python code drive the
            parser without touching ``sys.argv``.

    Returns:
        An ``argparse.Namespace`` with the attributes ``adapter_dir``,
        ``base_model_id``, ``repo_id``, ``hf_token``, ``private`` and
        ``branch``.
    """
    parser = argparse.ArgumentParser(
        description="Push a local adapter folder to the Hugging Face Hub."
    )
    parser.add_argument("adapter_dir",
                        help="Folder with *.safetensors, config, README, …")
    parser.add_argument("--base-model-id", default=None,
                        help="HF model ID to pull tokenizer/processor from")
    parser.add_argument("--repo-id", default=None,
                        help="Destination repo ID, e.g. username/adapter-name")
    # The environment is read when the parser is built, so an explicit
    # --hf-token on the command line always wins over $HF_TOKEN.
    parser.add_argument("--hf-token", default=os.getenv("HF_TOKEN"),
                        help="HF access token (or set $HF_TOKEN)")
    parser.add_argument("--private", action="store_true",
                        help="Create the repo as private")
    parser.add_argument("--branch", default="main",
                        help="Branch to push to (default: main)")
    # argparse falls back to sys.argv[1:] when argv is None.
    return parser.parse_args(argv)
# --------------------------------------------------------------------------- #

def get_base_model_name_or_path(adapter_dir):
    """Return the base model recorded in the adapter's ``adapter_config.json``.

    Args:
        adapter_dir: Directory (``str`` or ``Path``) that contains
            ``adapter_config.json``.

    Returns:
        The value stored under the ``"base_model_name_or_path"`` key.

    Raises:
        FileNotFoundError: If ``adapter_config.json`` does not exist in
            ``adapter_dir``.
        KeyError: If the config has no ``"base_model_name_or_path"`` entry.
    """
    cfg_path = Path(adapter_dir) / "adapter_config.json"
    if not cfg_path.exists():
        raise FileNotFoundError(f"Adapter config not found at {cfg_path}")

    # JSON files are UTF-8; do not depend on the platform's default encoding.
    with open(cfg_path, "r", encoding="utf-8") as f:
        cfg = json.load(f)

    try:
        return cfg["base_model_name_or_path"]
    except KeyError:
        # `from None` drops the redundant low-level KeyError from the
        # traceback; the message below already says everything.
        raise KeyError(
            "'base_model_name_or_path' not found in adapter_config.json"
        ) from None

def update_preprocessor_config(adapter_dir, repo_id):
    """Set ``processor_class`` in ``preprocessor_config.json`` from the repo name.

    The class is chosen from the part of ``repo_id`` after the optional
    ``namespace/`` prefix, compared case-insensitively against a fixed list
    of name prefixes. The file is rewritten in place (2-space indented JSON)
    only when a prefix matches.

    Args:
        adapter_dir: Directory (``str`` or ``Path``) that may contain
            ``preprocessor_config.json``.
        repo_id: Hub repo ID, with or without a ``namespace/`` prefix.

    Returns:
        None. Does nothing if ``preprocessor_config.json`` does not exist.
    """
    proc_cfg_path = Path(adapter_dir) / "preprocessor_config.json"
    if not proc_cfg_path.exists():
        return  # No preprocessor config to update

    # JSON files are UTF-8; do not depend on the platform's default encoding.
    with open(proc_cfg_path, "r", encoding="utf-8") as f:
        proc_cfg = json.load(f)

    # Repo-name prefix (lower-case) -> processor class written to the config.
    # No prefix is a prefix of another, so the lookup order is irrelevant.
    processor_by_prefix = (
        ("colvbert", "ColVBertProcessor"),
        ("bivbert", "BiVBertProcessor"),
        ("colvllama", "ColVLlamaProcessor"),
        ("bivllama", "BiVLlamaProcessor"),
    )
    repo_name = repo_id.split("/", 1)[-1].lower()
    cls = next(
        (name for prefix, name in processor_by_prefix
         if repo_name.startswith(prefix)),
        None,
    )

    if cls is None:
        print("ℹ️  No specific processor_class set; skipped updating preprocessor_config.json")
        return

    proc_cfg["processor_class"] = cls
    with open(proc_cfg_path, "w", encoding="utf-8") as f:
        json.dump(proc_cfg, f, indent=2)
    print(f"→ Set processor_class to '{cls}' in {proc_cfg_path.name}")

# --------------------------------------------------------------------------- #
def _derive_repo_id(src_dir):
    """Build a default ``SmolVEncoder/...`` repo ID from the three directories above *src_dir*."""
    parent = src_dir.parent.name
    grandparent = src_dir.parent.parent.name
    grandgrandparent = src_dir.parent.parent.parent.name
    # A shallow path yields empty names (the filesystem root has name ""),
    # which would silently produce a malformed repo ID such as "SmolVEncoder/--".
    if not (parent and grandparent and grandgrandparent):
        raise RuntimeError(
            f"Cannot derive a repo ID from '{src_dir}': it needs at least "
            "three named parent directories. Pass --repo-id explicitly."
        )
    col_part = parent.split("_", 1)[0]
    base_model_name = grandparent.split("-", 1)[-1]
    return f"SmolVEncoder/{col_part}-{base_model_name}-{grandgrandparent}"


def _stage_adapter_files(src_dir):
    """Copy the top-level files of *src_dir* into a fresh ``<src_dir>__complete`` folder and return it."""
    complete_dir = src_dir.with_name(src_dir.name + "__complete")
    if complete_dir.exists():
        # Start fresh to avoid stale files
        shutil.rmtree(complete_dir)
    complete_dir.mkdir(parents=True)

    # Only regular files directly inside src_dir are copied; sub-directories
    # are skipped.
    for file in src_dir.iterdir():
        if file.is_file():
            shutil.copy2(file, complete_dir)
            print(f"→ Copied {file.name}")
    return complete_dir


def main():
    """Stage an adapter folder and upload it to the Hugging Face Hub.

    Steps: parse the CLI arguments, fill in ``base_model_id`` and ``repo_id``
    when they were not given, log in and create the repo if needed, copy the
    adapter files plus ``./processor_config.json`` into a sibling
    ``<adapter_dir>__complete`` staging folder, and upload that folder to the
    requested branch.

    Raises:
        RuntimeError: If no token is available, the base model ID cannot be
            read from the adapter config, or no repo ID can be derived from
            the directory layout.
        FileNotFoundError: If the adapter directory does not exist (also
            raised by the copy step if ``./processor_config.json`` is missing).
        NotADirectoryError: If the adapter path exists but is not a directory.
    """
    args = parse_args()

    # An empty string (e.g. `HF_TOKEN=` in the environment) is as unusable as
    # a missing token, so test truthiness rather than `is None`.
    if not args.hf_token:
        raise RuntimeError("No token provided and $HF_TOKEN is empty.")

    src_dir = Path(args.adapter_dir).expanduser().resolve()
    if not src_dir.exists():
        raise FileNotFoundError(src_dir)
    # Fail before any Hub call instead of at iterdir(), after the repo exists.
    if not src_dir.is_dir():
        raise NotADirectoryError(src_dir)

    # --------------------------------------------------------------------- #
    # 0. Derive base-model-id if not given
    if args.base_model_id is None:
        try:
            args.base_model_id = get_base_model_name_or_path(src_dir)
        except (FileNotFoundError, KeyError) as e:
            raise RuntimeError(f"Failed to derive base model ID: {e}") from e
        print(f"ℹ️  Derived base_model_id='{args.base_model_id}' from adapter config")

    # 1. Derive repo-id if not given
    if args.repo_id is None:
        args.repo_id = _derive_repo_id(src_dir)
        print(f"ℹ️  Derived repo_id='{args.repo_id}' from adapter directory structure")

    # --------------------------------------------------------------------- #
    # 2. Log in & make sure the repo exists
    login(args.hf_token)
    api = HfApi()
    api.create_repo(
        repo_id=args.repo_id,
        repo_type="model",
        private=args.private,
        exist_ok=True,
    )
    print(f"✅ Repo ready: https://huggingface.co/{args.repo_id}")

    # --------------------------------------------------------------------- #
    # 3. Prepare the staging directory `<adapter_dir>__complete`
    complete_dir = _stage_adapter_files(src_dir)

    # --------------------------------------------------------------------- #
    # 4. Pull tokenizer (+ processor if present) from base model
    #    (currently disabled; kept for reference)
    # print(f"⏬ Pulling tokenizer from {args.base_model_id}")
    # tok = AutoTokenizer.from_pretrained(args.base_model_id, trust_remote_code=True)
    # tok.save_pretrained(complete_dir)

    # try:
    #     proc = AutoProcessor.from_pretrained(args.base_model_id, trust_remote_code=True)
    #     proc.save_pretrained(complete_dir)
    #     print("→ Saved processor files")
    # except (ValueError, OSError):
    #     # Fine for most text-only models
    #     print("ℹ️  No processor found – skipped.")

    # Set processor_class in preprocessor_config.json if it exists (disabled)
    # update_preprocessor_config(complete_dir, args.repo_id)

    # Ship the processor_config.json found in the current working directory
    # (not the adapter folder) alongside the adapter files.
    shutil.copy2(Path("./processor_config.json"), complete_dir)
    print("→ Copied processor_config.json")

    # --------------------------------------------------------------------- #
    # 5. Upload everything in one shot
    print("☁️  Uploading folder to the Hub …")
    api.upload_folder(
        folder_path=str(complete_dir),
        repo_id=args.repo_id,
        repo_type="model",
        path_in_repo=".",          # root of the repo
        commit_message="Add adapter checkpoint + tokenizer/processor",
        token=args.hf_token,
        revision=args.branch,
    )
    print("🚀 Pushed to the Hub!")

# --------------------------------------------------------------------------- #
# Run the upload only when this file is executed as a script, so importing the
# module (e.g. to reuse its helpers) does not trigger a push.
if __name__ == "__main__":
    main()
