from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import os
import argparse

def merge_lora_adapter(
    base_model_dir: str,
    lora_checkpoint_dir: str,
    merged_model_dir: str,
) -> None:
    """Merge a LoRA adapter into its base model and save the result.

    Loads the causal-LM base model and its tokenizer from ``base_model_dir``,
    attaches the adapter found at ``lora_checkpoint_dir``, merges the adapter
    weights into the base model via ``merge_and_unload()``, and writes the
    merged model together with the tokenizer to ``merged_model_dir``.

    Args:
        base_model_dir: Location handed to ``from_pretrained`` for both the
            base model and the tokenizer.
        lora_checkpoint_dir: Location of the LoRA adapter handed to
            ``PeftModel.from_pretrained``.
        merged_model_dir: Directory the merged model and tokenizer are saved
            to. It is created (including parents) when missing.

    Raises:
        OSError: If ``merged_model_dir`` cannot be created, e.g. because the
            path exists as a regular file or permissions are insufficient.
    """
    print(f"Merging LoRA adapter from {lora_checkpoint_dir} into {base_model_dir}")
    print(f"Output will be saved to {merged_model_dir}")

    # Create the output directory before loading anything so that an unusable
    # path fails immediately rather than after the expensive model load.
    # os.makedirs raises OSError on failure, so no separate existence check
    # is needed afterwards.
    os.makedirs(merged_model_dir, exist_ok=True)
    # exist_ok=True means the directory may already have been there, so the
    # message does not claim it was newly created.
    print(f"Output directory ready: {merged_model_dir}")

    # Load the base model and tokenizer.
    # NOTE(review): no dtype argument is passed, so precision follows the
    # library default -- confirm that is intended for large checkpoints.
    model = AutoModelForCausalLM.from_pretrained(base_model_dir)
    tokenizer = AutoTokenizer.from_pretrained(base_model_dir)
    print("Loaded base model and tokenizer")

    # Wrap the base model with the LoRA adapter weights.
    peft_model = PeftModel.from_pretrained(model, lora_checkpoint_dir)
    print("Loaded LoRA model")

    # Merge the LoRA weights into the base model; the returned model is what
    # gets saved below.
    merged_model = peft_model.merge_and_unload()
    print("Merged model")

    # Save the merged model and the tokenizer side by side.
    merged_model.save_pretrained(merged_model_dir)
    tokenizer.save_pretrained(merged_model_dir)

    print(f"Merged model saved to {merged_model_dir}")

if __name__ == "__main__":
    # Command-line entry point: collect the three locations and run the merge.
    cli = argparse.ArgumentParser(description="Merge LoRA adapter with base model")

    # Base model location; optional because a default is provided.
    cli.add_argument(
        '--base_model_dir',
        type=str,
        default="/is/cluster/fast/rolmedo/models/llama-3.1-8b-instruct/",
        help="Directory containing the base model",
    )
    # Adapter location; the only mandatory option.
    cli.add_argument(
        '--lora_checkpoint_dir',
        type=str,
        required=True,
        help="Directory containing the LoRA checkpoint",
    )
    # Output location; left as None when omitted so the fallback below applies.
    cli.add_argument(
        '--merged_model_dir',
        type=str,
        help="Directory to save the merged model (defaults to lora_dir/merged/)",
    )

    opts = cli.parse_args()

    # Fall back to a "merged" subdirectory of the adapter checkpoint when no
    # explicit output directory was given.
    output_dir = (
        opts.merged_model_dir
        if opts.merged_model_dir is not None
        else os.path.join(opts.lora_checkpoint_dir, "merged")
    )

    merge_lora_adapter(opts.base_model_dir, opts.lora_checkpoint_dir, output_dir)
