import argparse
from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path
import os


def merge_lora(args):
    """Load a checkpoint through LLaVA's builder on CPU and re-save it.

    Reads ``args.model_path`` and ``args.model_base`` and hands them to
    ``llava.model.builder.load_pretrained_model`` with ``device_map='cpu'``.
    It then writes the returned model and tokenizer to
    ``args.save_model_path`` via their ``save_pretrained`` methods (model
    first, then tokenizer).

    NOTE(review): presumably the builder is what folds LoRA weights into the
    base model when ``model_base`` is given; this function itself performs no
    merge. Confirm against ``load_pretrained_model``.

    Args:
        args: namespace-like object exposing ``model_path``, ``model_base``
            (may be None) and ``save_model_path``.
    """
    checkpoint_name = get_model_name_from_path(args.model_path)
    # The builder returns (tokenizer, model, image_processor, context_len);
    # only the first two are persisted here.
    tokenizer, model, _, _ = load_pretrained_model(
        args.model_path,
        args.model_base,
        checkpoint_name,
        device_map='cpu',
    )

    destination = args.save_model_path
    for artifact in (model, tokenizer):
        artifact.save_pretrained(destination)


# if __name__ == "__main__":
#     parser = argparse.ArgumentParser()
#     parser.add_argument("--model-path", type=str, required=True)
#     parser.add_argument("--model-base", type=str, required=False)
#     parser.add_argument("--save-model-path", type=str, required=True)

#     args = parser.parse_args()

#     merge_lora(args)
    
if __name__ == "__main__":
    # CLI entry point: the input checkpoint is <base-folder>/<model-name>, and
    # the result is written to <output-base-folder>/<model-name>.
    cli = argparse.ArgumentParser()
    # (flag, required) pairs, declared in the original order so --help output
    # and parsing behaviour are unchanged.
    for flag, is_required in (
        ("--base-folder", True),
        ("--model-base", False),
        ("--model-name", True),
        ("--output-base-folder", True),
    ):
        cli.add_argument(flag, type=str, required=is_required)
    cli_args = cli.parse_args()

    name = cli_args.model_name
    cli_args.model_path = os.path.join(cli_args.base_folder, name)
    cli_args.save_model_path = os.path.join(cli_args.output_base_folder, name)
    # Ensure the destination directory exists before loading/saving.
    os.makedirs(cli_args.save_model_path, exist_ok=True)

    merge_lora(cli_args)
