# Jinja2 template for merging LoRA models
# Used by benchmark.py to generate a merging script

template: |-
    """Merge a LoRA adapter into its base model and save the merged weights.

    This script is rendered from a Jinja2 template; the three paths below are
    substituted at render time. Every load passes local_files_only=True, so no
    files are fetched from the Hugging Face Hub.
    """
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from peft import PeftModel
    import os
    import shutil

    {# `string | tojson` renders each value as an escaped string literal, so
       quotes or backslashes in a path (e.g. Windows paths) cannot break the
       generated script; an undefined variable still renders as "". #}
    base_model_path = {{ base_model_path | string | tojson }}
    adapter_path = {{ adapter_path | string | tojson }}
    output_path = {{ output_path | string | tojson }}

    # output_path is deleted before saving, so refuse to run if it is (or
    # contains) one of the on-disk inputs -- that would destroy the sources.
    # Checked up front so we fail before spending time loading weights.
    output_real = os.path.realpath(output_path)
    for source_path in (base_model_path, adapter_path):
        if not os.path.exists(source_path):
            continue  # not a local path (e.g. a Hub model ID), nothing to protect
        source_real = os.path.realpath(source_path)
        if source_real == output_real or source_real.startswith(os.path.join(output_real, "")):
            raise ValueError(
                f"output_path {output_path!r} overlaps input {source_path!r}; "
                "refusing to delete it"
            )

    print(f"Loading base model from {base_model_path}...")
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_path,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
        local_files_only=True
    )

    print(f"Loading LoRA adapter from {adapter_path}...")
    model = PeftModel.from_pretrained(base_model, adapter_path, local_files_only=True)

    # Prefer a tokenizer saved alongside the adapter; fall back to the base
    # model's tokenizer if loading it from the adapter directory fails.
    print(f"Loading tokenizer from {adapter_path}...")
    try:
        tokenizer = AutoTokenizer.from_pretrained(adapter_path, trust_remote_code=True, local_files_only=True)
    except Exception as err:  # not bare: Ctrl-C / SystemExit must still propagate
        print("Tokenizer not found in adapter, loading from base model...")
        print(f"  (reason: {type(err).__name__}: {err})")
        tokenizer = AutoTokenizer.from_pretrained(base_model_path, trust_remote_code=True, local_files_only=True)

    print("Merging model...")
    # Fold the adapter weights into the base model and drop the PEFT wrapper.
    model = model.merge_and_unload()

    # lexists (not exists) so a dangling symlink is also cleared; otherwise
    # save_pretrained could not create the output directory at that path.
    if os.path.lexists(output_path):
        print(f"Removing existing output path: {output_path}")
        if os.path.isdir(output_path) and not os.path.islink(output_path):
            shutil.rmtree(output_path)
        else:
            # rmtree rejects plain files and symlinks; remove the entry itself
            # (for a symlink only the link is removed, not its target).
            os.remove(output_path)

    print(f"Saving merged model to {output_path}...")
    model.save_pretrained(output_path)
    tokenizer.save_pretrained(output_path)
    print("Merge Done.")
