import torch

from argparse import ArgumentParser
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel


def merge_model(base_model_path, adapter_path, new_model_path, torch_dtype=torch.bfloat16):
    """Merge a PEFT (e.g. LoRA) adapter into its base model and save the result.

    The tokenizer is read from ``adapter_path`` (the fine-tuned output). The base
    model's token embeddings are resized to ``len(tokenizer)`` before the
    adapter is attached. Presumably this is so an adapter trained with an
    extended vocabulary loads with matching shapes.

    Args:
        base_model_path: Location passed to ``AutoModelForCausalLM.from_pretrained``.
        adapter_path: Directory holding the fine-tuned adapter and its tokenizer.
        new_model_path: Output directory for the merged model and the tokenizer.
        torch_dtype: dtype used to load the base model (default ``torch.bfloat16``,
            matching the previous hard-coded behavior).
    """
    base_model = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype=torch_dtype)

    # The tokenizer comes from the fine-tuned adapter directory, which may
    # contain tokens added during fine-tuning.
    tokenizer = AutoTokenizer.from_pretrained(adapter_path)

    # Match the base model's embedding rows to the fine-tuned tokenizer before
    # attaching the adapter.
    # NOTE(review): if the base checkpoint's embedding matrix is padded beyond
    # len(tokenizer), this call shrinks it. Confirm this matches how the adapter
    # was trained.
    base_model.resize_token_embeddings(len(tokenizer))

    lora_model = PeftModel.from_pretrained(base_model, adapter_path)

    # Fold the adapter weights into the base layers and drop the PEFT wrapper.
    merged_model = lora_model.merge_and_unload()

    # Save the model first and the tokenizer last. A failure while loading or
    # merging then does not leave a tokenizer-only output directory behind.
    merged_model.save_pretrained(new_model_path)
    tokenizer.save_pretrained(new_model_path)

if __name__ == "__main__":
    # Command-line entry point: all three paths are mandatory. Without
    # required=True, a missing flag would pass None into from_pretrained and
    # fail with an opaque error instead of a usage message.
    parser = ArgumentParser(
        description="Merge a PEFT/LoRA adapter into its base model and save the merged model."
    )
    parser.add_argument(
        "--base_model_path",
        type=str,
        required=True,
        help="Please enter the path/to/base/model",
    )
    parser.add_argument(
        "--adapter_path",
        type=str,
        required=True,
        help="Please enter the path/to/finetuned/adapter",
    )
    parser.add_argument(
        "--new_model_path",
        type=str,
        required=True,
        help="Please enter the path/to/new/model",
    )

    args = parser.parse_args()
    merge_model(args.base_model_path, args.adapter_path, args.new_model_path)