import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
)
from peft import PeftModel
import argparse


def merge_models(model_name, peft_model, merged_peft_model_name):
    """Merge a PEFT adapter into its base model and save the result.

    The tokenizer is loaded from the adapter location, the base causal-LM is
    loaded in float16, the adapter is attached to it, and the combined model
    returned by ``merge_and_unload()`` is written to disk together with the
    tokenizer.

    Args:
        model_name: Name or path of the base model, passed to
            ``AutoModelForCausalLM.from_pretrained``.
        peft_model: Name or path of the PEFT adapter. The tokenizer is also
            loaded from this location.
        merged_peft_model_name: Directory that receives the merged model and
            the tokenizer.
    """
    tokenizer = AutoTokenizer.from_pretrained(peft_model)

    # No ``token=`` argument is passed here: in ``from_pretrained`` that
    # parameter is the Hub authentication token (a str or bool), not a
    # tokenizer object.
    base_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        device_map="auto",
        offload_folder="offload",
    )

    # Wrap the base model with the adapter, using the same dtype/placement
    # settings as the base model load above.
    adapted_model = PeftModel.from_pretrained(
        base_model,
        peft_model,
        torch_dtype=torch.float16,
        device_map="auto",
        offload_folder="offload",
    )

    merged_model = adapted_model.merge_and_unload()

    # Save model and tokenizer side by side so the output directory can be
    # loaded on its own.
    merged_model.save_pretrained(merged_peft_model_name)
    tokenizer.save_pretrained(merged_peft_model_name)

if __name__ == "__main__":
    # Both the adapter input and the merged output are resolved relative to
    # this directory.
    models_dir = "./results/models/"

    parser = argparse.ArgumentParser(
        description="Merge a PEFT adapter into its base model",
    )
    parser.add_argument(
        "--original_model",
        type=str,
        default="meta-llama/Meta-Llama-3-8B-Instruct",
        help="Name or path of the base model to merge the adapter into",
    )
    parser.add_argument(
        "--peft_model",
        type=str,
        default="2024-04-30_10-46-52",
        help=f"Directory name of the PEFT adapter, relative to {models_dir}",
    )
    parser.add_argument(
        "--merged_model_name",
        type=str,
        default="minigrid_merged_model",
        help=f"Directory name for the merged model, relative to {models_dir}",
    )

    args = parser.parse_args()

    merge_models(
        args.original_model,
        models_dir + args.peft_model,
        models_dir + args.merged_model_name,
    )
    
    