#!/usr/bin/env python3
"""
Merge LoRA adapter weights into a base model and save the result as a full
model for inference evaluation.

Example usage:

python merge_lora.py \
    --base_model GSAI-ML/LLaDA-8B-Instruct \
    --adapter_path ./sft_output/llada-s1_loss_vanilla_selection_random/checkpoint-7425 \
    --output_path ./merged_models/llada-s1_loss_vanilla_selection_random/checkpoint-7425
"""

import torch
from transformers import AutoTokenizer, AutoModel
from peft import PeftModel
import argparse
import os

def merge_lora_model(base_model_path, adapter_path, output_path):
    """
    Fold a LoRA adapter into its base model and write out a standalone model.

    The tokenizer and the base weights (requested as bfloat16) are loaded from
    ``base_model_path``, the adapter found at ``adapter_path`` is attached and
    merged into those weights, and the merged model together with the tokenizer
    is saved under ``output_path`` (created if it does not already exist).
    Progress messages are printed to stdout; nothing is returned.
    """
    print(f"Loading base model from: {base_model_path}")

    # Tokenizer and model are both loaded with trust_remote_code enabled.
    remote_code = {"trust_remote_code": True}
    tok = AutoTokenizer.from_pretrained(base_model_path, use_fast=True, **remote_code)
    backbone = AutoModel.from_pretrained(
        base_model_path, torch_dtype=torch.bfloat16, **remote_code
    )

    print(f"Loading LoRA adapter from: {adapter_path}")
    net = PeftModel.from_pretrained(backbone, adapter_path)

    print("Merging LoRA weights with base model...")
    # Rebind the same name so only the merged, adapter-free model stays referenced.
    net = net.merge_and_unload()

    print(f"Saving merged model to: {output_path}")
    os.makedirs(output_path, exist_ok=True)
    # Model first, then tokenizer, both into the same directory.
    for artifact in (net, tok):
        artifact.save_pretrained(output_path)

    print("✅ Model merged and saved successfully!")
    print(f"You can now use this path for inference: {output_path}")

if __name__ == "__main__":
    # Command-line entry point: collect the three paths and run the merge.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--base_model",
        type=str,
        default="GSAI-ML/LLaDA-8B-Instruct",
        help="Base model path",
    )
    # The adapter and output locations have no sensible default, so both are required.
    for flag, description in (
        ("--adapter_path", "Path to LoRA adapter"),
        ("--output_path", "Output path for merged model"),
    ):
        cli.add_argument(flag, type=str, required=True, help=description)

    options = cli.parse_args()

    merge_lora_model(
        base_model_path=options.base_model,
        adapter_path=options.adapter_path,
        output_path=options.output_path,
    )