import argparse
import json
import os
import subprocess
import sys

def main(argv=None):
    """Write a LLaMA-Factory merge config for a LoRA adapter and run the export.

    The config is saved to ``<adapter>-merge/merge.json`` and then passed to
    ``llamafactory-cli export``, with ``export_dir`` set to ``<adapter>-merge``.

    Args:
        argv: Optional argument list to parse instead of ``sys.argv[1:]``
            (lets the CLI be driven programmatically). Defaults to None.

    Raises:
        SystemExit: On invalid arguments (from argparse), when
            ``llamafactory-cli`` is not found on PATH, or with a non-zero
            status when the export command fails.
    """
    parser = argparse.ArgumentParser(description="Merge LoRA adapter with base model")
    parser.add_argument("--model_name_or_path", type=str, required=True,
                        help="Path or name of the base model (e.g., google/gemma-3-12b-it)")
    parser.add_argument("--adapter_name_or_path", type=str, required=True,
                        help="Name of the adapter directory (e.g., llama3-8b_lora_sft)")
    parser.add_argument("--template", type=str, required=True,
                        help="Template name that matches training template (e.g., gemma3, llama3)")
    parser.add_argument("--export_size", type=int, default=5,
                        help="File shard size in GB (default: 5)")
    parser.add_argument("--export_device", type=str, default="cpu", choices=["cpu", "auto"],
                        help="Export device, can be 'cpu' or 'auto' (default: cpu)")
    parser.add_argument("--dry_run", action="store_true",
                        help="Only write the merge config; do not run llamafactory-cli")

    args = parser.parse_args(argv)

    adapter_name_or_path = args.adapter_name_or_path
    # Drop trailing path separators first: otherwise "saves/lora/" would give
    # "saves/lora/-merge", a "-merge" directory nested inside the adapter dir.
    separators = os.sep + (os.altsep or "")
    export_base = adapter_name_or_path.rstrip(separators) or adapter_name_or_path
    export_dir = f"{export_base}-merge"
    config_path = os.path.join(export_dir, "merge.json")

    # Configuration consumed by `llamafactory-cli export`.
    config = dict(
        model_name_or_path=args.model_name_or_path,
        adapter_name_or_path=adapter_name_or_path,
        template=args.template,
        finetuning_type="lora",
        export_dir=export_dir,
        export_size=args.export_size,
        export_device=args.export_device,
    )

    # The config file lives inside the export directory, so create it first.
    os.makedirs(export_dir, exist_ok=True)

    with open(config_path, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)

    print(f"Configuration saved to: {config_path}")
    print(f"Merging adapter: {adapter_name_or_path}")
    print(f"Output directory: {export_dir}")

    # Pass an argument list (no shell) so paths containing spaces or shell
    # metacharacters reach llamafactory-cli intact and cannot inject commands.
    command = ["llamafactory-cli", "export", config_path]
    if args.dry_run:
        print(f"Dry run: skipping {' '.join(command)}")
        return

    # NOTE: llamafactory-cli must be on PATH. The child process inherits the
    # current working directory, so relative adapter/config paths resolve there.
    try:
        result = subprocess.run(command)
    except FileNotFoundError:
        sys.exit("Error: 'llamafactory-cli' was not found on PATH; is LLaMA-Factory installed?")
    if result.returncode != 0:
        # Propagate failure so shells/CI notice it. Negative codes mean the
        # child was killed by a signal on POSIX; map those to 1.
        sys.exit(result.returncode if result.returncode > 0 else 1)

if __name__ == "__main__":
    # Executed as a script rather than imported: hand control to the CLI and
    # exit with whatever it returns (None -> status 0).
    raise SystemExit(main())
