#!/usr/bin/env python3
"""
Merge a LoRA adapter into a base model and save the result.

Supports loading LoRA from HuggingFace Hub (with subfolder) or local path.
Handles authentication for private repos via HF_TOKEN env var or --token arg.

Usage:
    python merge_lora.py \
        --base_model Qwen/Qwen2.5-7B-Instruct \
        --lora_path anon-neurips26/qwen25_7b_bridges_dsr \
        --lora_subfolder best_checkpoint_bridges_5x5de_test200_intformat_json \
        --output_dir checkpoints/merged_model \
        --token $HF_TOKEN
"""

import os
import sys
import argparse


def _resolve_torch_dtype(torch_dtype):
    """Translate a dtype name such as ``"float32"`` into a ``torch.dtype``.

    ``"auto"`` is returned unchanged so transformers picks the dtype itself,
    and an existing ``torch.dtype`` object is passed through as-is.

    Raises:
        ValueError: if the name does not name a ``torch.dtype`` attribute.
    """
    if torch_dtype == "auto":
        return torch_dtype
    # Imported lazily (like transformers/peft below) so `--help` stays fast.
    import torch

    if isinstance(torch_dtype, torch.dtype):
        return torch_dtype
    resolved = getattr(torch, str(torch_dtype), None)
    if not isinstance(resolved, torch.dtype):
        raise ValueError(
            f"Unsupported torch_dtype {torch_dtype!r}; expected 'auto' or a torch "
            "dtype name such as 'float32', 'bfloat16' or 'float16'."
        )
    return resolved


def _resolve_adapter_dir(lora_path: str, lora_subfolder: str = None, token: str = None) -> str:
    """Return a local directory holding the LoRA adapter files.

    An existing local directory is used directly; anything else is treated as a
    HuggingFace Hub repo ID and fetched with ``snapshot_download`` (only the
    requested subfolder, when one is given). The subfolder is applied in BOTH
    cases.

    Raises:
        FileNotFoundError: if the requested subfolder does not exist locally
            (or was not present in the downloaded snapshot).
    """
    # Normalise "sub/" -> "sub" so the allow pattern below does not become "sub//*".
    subfolder = lora_subfolder.strip("/") if lora_subfolder else None

    downloaded = not os.path.isdir(lora_path)
    if downloaded:
        # Download to a local dir first to bypass peft's unreliable auth:
        # PeftModel.from_pretrained has inconsistent token forwarding across versions.
        from huggingface_hub import snapshot_download
        print("  Downloading adapter from HF Hub...")
        adapter_dir = snapshot_download(
            repo_id=lora_path,
            token=token,
            allow_patterns=f"{subfolder}/*" if subfolder else None,
        )
    else:
        adapter_dir = lora_path

    if subfolder:
        adapter_dir = os.path.join(adapter_dir, subfolder)
        if not os.path.isdir(adapter_dir):
            raise FileNotFoundError(
                f"LoRA subfolder {subfolder!r} not found under {lora_path!r} "
                f"(looked in {adapter_dir})"
            )

    if downloaded:
        print(f"  Downloaded to: {adapter_dir}")
    return adapter_dir


def _clear_greedy_sampling_params(model) -> None:
    """Reset sampling-only generation settings when sampling is disabled.

    Some base models (e.g. OLMo3) ship a generation config that sets
    temperature/top_p without ``do_sample=True``; transformers' generation
    config validation flags such fields as invalid for greedy decoding. Every
    such field that is set is reset to ``None``; nothing changes when
    ``do_sample`` is truthy.
    """
    gen_cfg = getattr(model, "generation_config", None)
    if gen_cfg is None or getattr(gen_cfg, "do_sample", False):
        return
    for attr in ("temperature", "top_p", "top_k", "min_p", "typical_p"):
        if getattr(gen_cfg, attr, None) is not None:
            setattr(gen_cfg, attr, None)


def merge_lora(
    base_model: str,
    lora_path: str,
    output_dir: str,
    lora_subfolder: str = None,
    token: str = None,
    torch_dtype: str = "float32",
) -> str:
    """
    Merge a LoRA adapter into a base model and save.

    The dtype and the adapter location are resolved before the (slow) base
    model load, so a bad dtype name or a missing adapter subfolder fails fast.

    Args:
        base_model: Base model path or HF repo ID
        lora_path: LoRA adapter path or HF repo ID. Anything that is not an
                   existing local directory is treated as a Hub repo ID.
        output_dir: Directory to save merged model
        lora_subfolder: Subfolder within lora_path (for repos/dirs with multiple
                        adapters). Applied to both local and Hub paths.
        token: HuggingFace API token for private repos
        torch_dtype: Model dtype name (default: "float32"), "auto", or a
                     ``torch.dtype``. Always use float32 to avoid precision
                     loss — bf16 merge degrades AIME24 accuracy.

    Returns:
        Path to the merged output directory

    Raises:
        ValueError: if ``torch_dtype`` is not a recognised dtype.
        FileNotFoundError: if ``lora_subfolder`` does not exist.
    """
    resolved_dtype = _resolve_torch_dtype(torch_dtype)

    # Set HF auth via multiple methods to ensure it works across all library versions:
    # 1. Set env var (huggingface_hub reads HF_TOKEN automatically)
    # 2. Call login() (sets global token in huggingface_hub internals; note this
    #    also stores the token in huggingface_hub's local token cache)
    # 3. Pass token= kwarg to each call (explicit)
    if token:
        os.environ["HF_TOKEN"] = token
        from huggingface_hub import login
        login(token=token, add_to_git_credential=False)

    print(f"Loading LoRA adapter: {lora_path}", end="")
    print(f" / {lora_subfolder}" if lora_subfolder else "")
    local_adapter_path = _resolve_adapter_dir(lora_path, lora_subfolder, token)

    from transformers import AutoModelForCausalLM, AutoTokenizer
    from peft import PeftModel

    print(f"Loading base model: {base_model}")
    model = AutoModelForCausalLM.from_pretrained(
        base_model, torch_dtype=resolved_dtype, token=token
    )
    # NOTE(review): the tokenizer comes from the base model; if the adapter was
    # trained with added tokens, its own tokenizer may be needed — confirm.
    tokenizer = AutoTokenizer.from_pretrained(base_model, token=token)

    model = PeftModel.from_pretrained(model, local_adapter_path)

    print("Merging LoRA into base model...")
    model = model.merge_and_unload()

    print(f"Saving merged model to: {output_dir}")
    os.makedirs(output_dir, exist_ok=True)
    _clear_greedy_sampling_params(model)
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    print("Merge complete!")

    return output_dir


def main():
    """Parse command-line options and merge the requested LoRA adapter.

    The HuggingFace token is taken from ``--token`` when given, otherwise from
    the ``HF_TOKEN`` and then ``HUGGING_FACE_HUB_TOKEN`` environment variables;
    it stays ``None`` when none of them is set.
    """
    dtype_help = (
        "Model dtype for merging (default: float32). "
        "Use float32 to avoid precision loss during LoRA merge. "
        "bf16 merge degrades AIME24 accuracy (6.7%% vs 10%% with LoRA loading)."
    )
    # (flag, extra add_argument kwargs) — every option is parsed as a string.
    option_specs = (
        ("--base_model", dict(required=True, help="Base model path or HF repo ID")),
        ("--lora_path", dict(required=True, help="LoRA adapter path or HF repo ID")),
        ("--lora_subfolder", dict(default=None, help="Subfolder within lora_path")),
        ("--output_dir", dict(required=True, help="Directory to save merged model")),
        ("--token", dict(
            default=None,
            help="HuggingFace API token (uses HF_TOKEN env var if not set)",
        )),
        ("--torch_dtype", dict(default="float32", help=dtype_help)),
    )

    cli = argparse.ArgumentParser(description="Merge LoRA adapter into base model")
    for flag, extra in option_specs:
        cli.add_argument(flag, type=str, **extra)
    opts = cli.parse_args()

    hf_token = (
        opts.token
        or os.environ.get("HF_TOKEN")
        or os.environ.get("HUGGING_FACE_HUB_TOKEN")
    )

    merge_lora(
        base_model=opts.base_model,
        lora_path=opts.lora_path,
        output_dir=opts.output_dir,
        lora_subfolder=opts.lora_subfolder,
        token=hf_token,
        torch_dtype=opts.torch_dtype,
    )


if __name__ == "__main__":
    # Only runs when executed as a script; importing the module has no side effects.
    main()
